From 60e7ee0611745f07f1a90d32e6253f9c4161eaee Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Wed, 28 Feb 2018 15:38:01 +0800 Subject: [PATCH 01/40] refine concat_op --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/concat_op.cc | 9 +- paddle/fluid/operators/concat_op.h | 53 +---- paddle/fluid/operators/math/CMakeLists.txt | 3 + paddle/fluid/operators/math/concat.cc | 89 +++++++ paddle/fluid/operators/math/concat.cu | 154 ++++++++++++ paddle/fluid/operators/math/concat.h | 37 +++ paddle/fluid/operators/math/concat_test.cc | 262 +++++++++++++++++++++ 8 files changed, 559 insertions(+), 49 deletions(-) create mode 100644 paddle/fluid/operators/math/concat.cc create mode 100644 paddle/fluid/operators/math/concat.cu create mode 100644 paddle/fluid/operators/math/concat.h create mode 100644 paddle/fluid/operators/math/concat_test.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4da46e94c5..266303b4cb 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) +op_library(concat_op DEPS concat_functor) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index bdce8f0a6f..0eedd8ee51 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP_EX(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad, ops::ConcatOpGrad, false) -REGISTER_OP_CPU_KERNEL(concat, - ops::ConcatKernel<paddle::platform::CPUPlace, float>) -REGISTER_OP_CPU_KERNEL(concat_grad, - ops::ConcatGradKernel<paddle::platform::CPUPlace, float>) +REGISTER_OP_CPU_KERNEL( + concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>) +REGISTER_OP_CPU_KERNEL( + concat_grad, + ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, float>) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 208a4481c6..19d877dfb6 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include <utility> #include <vector> #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/strided_memcpy.h" namespace paddle { @@ -27,55 +28,17 @@ class ConcatKernel : public framework::OpKernel<T> { public: void Compute(const framework::ExecutionContext& ctx) const override { auto ins = ctx.MultiInput<framework::Tensor>("X"); - auto* out = ctx.Output<framework::Tensor>("Out"); + framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); auto place = ctx.GetPlace(); out->mutable_data<T>(place); - - auto out_stride = framework::stride_numel(out->dims()); - - size_t output_offset = 0; - - // If axis >=1, copy to out immediately need to call many times - // of cuda memcpy. Copy the input to cpu and do the stride copy, - // then copy to gpu output. 
- - if (platform::is_gpu_place(place) && axis >= 1) { - platform::CPUPlace copy_place; - auto& cpu_ctx = *platform::DeviceContextPool::Instance().Get(copy_place); - framework::Tensor cpu_out; - cpu_out.Resize(out->dims()); - cpu_out.mutable_data<T>(copy_place); - auto& dev_ctx = ctx.device_context(); - std::vector<std::unique_ptr<framework::Tensor>> cpu_ins; - for (auto* in : ins) { - std::unique_ptr<framework::Tensor> cpu_in(new framework::Tensor); - framework::TensorCopy(*in, copy_place, dev_ctx, cpu_in.get()); - cpu_ins.emplace_back(std::move(cpu_in)); - } - // TODO(dzhwinter): overlap copy and compute stream - // https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/ - dev_ctx.Wait(); - - for (auto& in : cpu_ins) { - auto& cpu_in = *in.get(); - auto in_stride = framework::stride_numel(cpu_in.dims()); - - StridedNumelCopyWithAxis<T>( - cpu_ctx, axis, cpu_out.data<T>() + output_offset, out_stride, - cpu_in.data<T>(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; - } - framework::TensorCopy(cpu_out, place, dev_ctx, out); - } else { - for (auto* in : ins) { - auto in_stride = framework::stride_numel(in->dims()); - StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, - out->data<T>() + output_offset, out_stride, - in->data<T>(), in_stride, in_stride[axis]); - output_offset += in_stride[axis]; - } + std::vector<framework::Tensor> inputs(ins.size()); + for (size_t j = 0; j < ins.size(); ++j) { + inputs[j] = *ins[j]; } + auto& dev_ctx = ctx.template device_context<DeviceContext>(); + paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor; + concat_functor(dev_ctx, inputs, static_cast<int>(axis), out); } }; diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 768106fadf..751e69b1c8 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_GPU) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) + nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -37,6 +38,7 @@ else() cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) + cc_library(concat_functor SRCS concat.cc DEPS device_context tensor) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) @@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) +cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor) diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc new file mode 100644 index 0000000000..32059aa2f0 --- /dev/null +++ b/paddle/fluid/operators/math/concat.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2018 
PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * All tensors' dimension should be the same. + */ +template <typename T> +class ConcatFunctor<platform::CPUDeviceContext, T> { + public: + void operator()(const platform::CPUDeviceContext& context, + std::vector<framework::Tensor>& input, const int axis, + framework::Tensor* output) { + // assume the the max size of input is less than 8 and see the performance + // save origin dim + int num = input.size(); + std::vector<paddle::framework::DDim> origin_dim(num); + // for (int j = 0; j < num; ++j) { + // origin_dim[j] = input[j].dims(); + // } + auto out_dim = output->dims(); + + // get the matrix size + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int cols = input[0].numel() / rows; + int out_rows = rows, out_cols = 0; + bool sameShape = true; + + // reshape to matrix + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + if (sameShape) { + if (t_cols != cols) sameShape = false; + } + out_cols += t_cols; + input[i].Resize({rows, t_cols}); + } + output->Resize({out_rows, out_cols}); + auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); + // computation + for (int k = 0; k < rows; ++k) { + // offset k * out_cols + T* dst_ptr = output->data<T>() + k * out_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input[j].dims()[1]; + const T* src_prt = input[j].data<T>() + k * col_len; + memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, + sizeof(T) * col_len); + col_idx += col_len; + } + } + + // recover origin dim + // for (int j = 0; j < num; ++j) { + // input[j]->Resize(origin_dim[j]); + // } + output->Resize(out_dim); + } +}; + +template class ConcatFunctor<platform::CPUDeviceContext, int>; +template class ConcatFunctor<platform::CPUDeviceContext, int64_t>; +template class ConcatFunctor<platform::CPUDeviceContext, float>; +template class ConcatFunctor<platform::CPUDeviceContext, double>; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu new file mode 100644 index 0000000000..6932e22f84 --- /dev/null +++ b/paddle/fluid/operators/math/concat.cu @@ -0,0 +1,154 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +// TODO(zcd): This can be replaced by tensor, +// if that, maybe we should add int8 to VarType::Type. +// Or replaced by tensorArray. +static constexpr int MaxSize = 32; +template <typename T> +struct CUDADeviceArray { + T data[MaxSize]; + int size; +}; + +template <typename T> +__device__ T upper_bound(const T* first, T count, T val) { + const T* orig = first; + const T* it = nullptr; + T step = 0; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (!(val < *it)) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + return first - orig; +} + +template <typename T> +__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, + const CUDADeviceArray<int> input_cols, + const int output_rows, const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1; + + int curr_offset = input_cols.data[segment]; + int curr_segment = segment; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + const T* input_ptr = inputs.data[curr_segment]; + + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * segment_width + local_col]; + } +} + +/* + * All tensors' dimension should be the same. 
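+ * Each input is viewed as a [rows, cols_i] matrix, where rows is the
+ * product of the dimensions in front of axis; the kernel maps every
+ * output column back to its source input via upper_bound.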
+ */ +template <typename T> +class ConcatFunctor<platform::CUDADeviceContext, T> { + public: + void operator()(const platform::CUDADeviceContext& context, + std::vector<framework::Tensor>& input, const int axis, + framework::Tensor* output) { + // assume the the max size of input is less than 8 and see the performance + // save origin dim + int num = input.size(); + // std::vector<paddle::framework::DDim> origin_dim(num); + // for (int j = 0; j < num; ++j) { + // origin_dim[j] = input[j].dims(); + // } + auto out_dim = output->dims(); + + // get the matrix size + int rows = 1; + auto dim_0 = input[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int cols = input[0].numel() / rows; + int out_rows = rows, out_cols = 0; + bool sameShape = true; + + CUDADeviceArray<const T*> inputs_data; + CUDADeviceArray<int> inputs_cols; + inputs_data.size = num; + inputs_cols.size = num + 1; + inputs_cols.data[0] = 0; + // reshape to matrix + // check input shape is valid + for (int i = 0; i < num; ++i) { + int t_cols = input[i].numel() / rows; + if (sameShape) { + if (t_cols != cols) sameShape = false; + } + out_cols += t_cols; + input[i].Resize({rows, t_cols}); + inputs_cols.data[i + 1] = out_cols; + inputs_data.data[i] = input[i].data<T>(); + } + output->Resize({out_rows, out_cols}); + + // computation + const int kThreadsPerBlock = 256; + int block_cols = std::min(out_cols, kThreadsPerBlock); + int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + dim3 block_size = dim3(block_cols, block_rows, 1); + + int grid_cols = (out_cols + block_cols - 1) / block_cols; + int grid_rows = (out_rows + block_rows - 1) / block_rows; + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( + inputs_data, inputs_cols, out_rows, out_cols, output->data<T>()); + + // recover origin dim + // for (int j = 0; j < num; ++j) { + // input[j].Resize(origin_dim[j]); + // } + output->Resize(out_dim); + } +}; + +template class ConcatFunctor<platform::CUDADeviceContext, int>; +template class ConcatFunctor<platform::CUDADeviceContext, int64_t>; +template class ConcatFunctor<platform::CUDADeviceContext, float>; +template class ConcatFunctor<platform::CUDADeviceContext, double>; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h new file mode 100644 index 0000000000..50c75dd208 --- /dev/null +++ b/paddle/fluid/operators/math/concat.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * the tensor's shape of input will be changed, + * so the second parameter is not const. 
+ * + */ +template <typename DeviceContext, typename T> +class ConcatFunctor { + public: + void operator()(const DeviceContext& context, + std::vector<framework::Tensor>& input, const int axis, + framework::Tensor* output); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc new file mode 100644 index 0000000000..815070b113 --- /dev/null +++ b/paddle/fluid/operators/math/concat_test.cc @@ -0,0 +1,262 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/concat.h" +#include <gtest/gtest.h> +#include <vector> +#include "paddle/fluid/framework/tensor_util.h" + +using namespace paddle::framework; +using namespace paddle::platform; + +template <typename DeviceContext, typename Place> +void testConcat() { + Tensor input_a_cpu; + Tensor input_b_cpu; + Tensor out_cpu; + Tensor input_a; + Tensor input_b; + Tensor out; + + DeviceContext* context = new DeviceContext(Place()); + // DeviceContext context(Place()); + + /** + * cast1: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [3, 3, 4] + * output: + * out.shape: [5, 3, 4] + */ + auto dim_a = make_ddim({2, 3, 4}); + auto dim_b = make_ddim({3, 3, 4}); + auto dim_out = make_ddim({5, 3, 4}); + + input_a.mutable_data<int>(dim_a, Place()); + input_b.mutable_data<int>(dim_b, Place()); + out.mutable_data<int>(dim_out, Place()); + + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.mutable_data<int>(dim_a, CPUPlace()); + input_b_cpu.mutable_data<int>(dim_b, CPUPlace()); + out_cpu.mutable_data<int>(dim_out, CPUPlace()); + } + + int* a_ptr; + int* b_ptr; + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data<int>(); + b_ptr = input_b_cpu.data<int>(); + } else { + a_ptr = input_a.data<int>(); + b_ptr = input_b.data<int>(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 3 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + std::vector<Tensor> input; + input.push_back(input_a); + input.push_back(input_b); + + paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor; + concat_functor(*context, input, 0, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + int* out_ptr; + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data<int>(); + } else { + out_ptr = out.data<int>(); + } + + int cols = 2 * 3 * 4; + int idx_a = 0, idx_b = 0; + for (int j = 0; j < 5 * 3 * 4; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[j], a_ptr[idx_a]); + ++idx_a; + } + } + // + /** + * cast2: + * inputs: + * 
t_a.shape: [2, 3, 4] + * t_b.shape: [2, 4, 4] + * output: + * out.shape: [2, 7, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 4, 4}); + dim_out = make_ddim({2, 7, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data<int>(); + b_ptr = input_b_cpu.data<int>(); + } else { + a_ptr = input_a.data<int>(); + b_ptr = input_b.data<int>(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 4 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data<int>(); + } else { + out_ptr = out.data<int>(); + } + + cols = 3 * 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 28; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 28 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } + + /** + * case3: + * inputs: + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 3, 5] + * output: + * out.shape: [2, 3, 9] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 5}); + dim_out = make_ddim({2, 3, 9}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data<int>(); + b_ptr = input_b_cpu.data<int>(); + } else { + a_ptr = input_a.data<int>(); + b_ptr = input_b.data<int>(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 5; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 2, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data<int>(); + } else { + out_ptr = out.data<int>(); + } + + // check the data + cols = 4; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 6; ++i) { + for (int j = 0; j < 9; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 9 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } +} + +TEST(math, concat) { + testConcat<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>(); +#ifdef PADDLE_WITH_CUDA + testConcat<paddle::platform::CUDADeviceContext, + paddle::platform::CUDAPlace>(); +#endif +}
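The new functor can also be driven outside of ConcatKernel; a minimal CPU-side sketch (hypothetical driver code, not part of the patch, assuming only the headers added above) is:

#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/device_context.h"

// Concatenate two CPU tensors along axis 0, mirroring what ConcatKernel
// does with its MultiInput list. Tensor copies are shallow, so pushing
// them into the vector does not duplicate the underlying buffers.
void ConcatSketch() {
  paddle::platform::CPUPlace place;
  paddle::platform::CPUDeviceContext ctx(place);

  paddle::framework::Tensor a, b, out;
  a.mutable_data<float>(paddle::framework::make_ddim({2, 4}), place);
  b.mutable_data<float>(paddle::framework::make_ddim({3, 4}), place);
  out.mutable_data<float>(paddle::framework::make_ddim({5, 4}), place);

  std::vector<paddle::framework::Tensor> inputs{a, b};
  paddle::operators::math::ConcatFunctor<paddle::platform::CPUDeviceContext,
                                         float>
      concat_functor;
  concat_functor(ctx, inputs, 0 /*axis*/, &out);  // rows of a, then rows of b
}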
From f67275a920f5dc7822a240852588fd6f5f4777d5 Mon Sep 17 00:00:00 2001 From: Luo Tao <luotao02@baidu.com> Date: Thu, 1 Mar 2018 17:34:25 +0800 Subject: [PATCH 02/40] refine operator/math/CMakeLists.txt, separate im2col from math_function --- paddle/fluid/operators/CMakeLists.txt | 6 +- paddle/fluid/operators/math/CMakeLists.txt | 87 ++++++++++++---------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4da46e94c5..9f6756541e 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -173,11 +173,11 @@ op_library(parallel_do_op DEPS executor) op_library(create_reader_op DEPS reader) if (WITH_GPU) - op_library(conv_op DEPS vol2col depthwise_conv) + op_library(conv_op DEPS vol2col depthwise_conv im2col) else() - op_library(conv_op DEPS vol2col) + op_library(conv_op DEPS vol2col im2col) endif() -op_library(conv_transpose_op DEPS vol2col) +op_library(conv_transpose_op DEPS vol2col im2col) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 768106fadf..49219d97af 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -1,46 +1,57 @@ add_subdirectory(detail) +function(math_library TARGET) + # math_library is a function to create a math library. + # The interface is the same as cc_library. + # But it handles splitting GPU/CPU code and linking some common libraries. + set(cc_srcs) + set(cu_srcs) + set(math_common_deps device_context framework_proto) + set(multiValueArgs SRCS DEPS) + cmake_parse_arguments(math_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc) + list(APPEND cc_srcs ${TARGET}.cc) + endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) + list(APPEND cu_srcs ${TARGET}.cu) + endif() + + if (WITH_GPU) + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) + else() + cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) + endif() +endfunction() + +math_library(math_function DEPS cblas) +math_library(im2col) +math_library(selected_rows_functor DEPS selected_rows) +math_library(softmax) +math_library(cross_entropy) +math_library(pooling) +math_library(sequence_pooling) +math_library(vol2col) +math_library(context_project) +math_library(sequence2batch) +math_library(sequence_padding) +math_library(sequence_scale) +math_library(maxouting) +math_library(unpooling) +math_library(cos_sim_functor) +math_library(lstm_compute DEPS activation_functions) +math_library(gru_compute DEPS activation_functions) if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context framework_proto) - nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) - nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function) - nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) - nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context) - nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context) - nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context) - nv_library(sequence_pooling SRCS
sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) - nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) - nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) - nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor math_function) - nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context) - nv_library(sequence_scale SRCS sequence_scale.cc sequence_scale.cu DEPS lod_tensor device_context) - nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) - nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) - nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) - nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) - nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) -else() - cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) - cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) - cc_library(softmax SRCS softmax.cc DEPS device_context) - cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context) - cc_library(pooling SRCS pooling.cc DEPS device_context) - cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) - cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor) - cc_library(context_project SRCS context_project.cc DEPS device_context math_function) - cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor math_function) - cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context) - cc_library(sequence_scale SRCS sequence_scale.cc DEPS lod_tensor device_context) - cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(maxouting SRCS maxouting.cc DEPS device_context) - cc_library(unpooling SRCS unpooling.cc DEPS device_context) - cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) - cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) endif() -cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) +cc_test(math_function_test SRCS math_function_test.cc) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) -cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) -cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) +cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) +if(WITH_GPU) + nv_test(math_function_gpu_test SRCS math_function_test.cu) + nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) +endif() From baf70dc8f31cfbfc39cf83d8c5a45af3dba969be Mon Sep 17 00:00:00 2001 From: Yancey1989 <yancey1989@gmail.com> Date: Fri, 2 Mar 2018 16:37:54 +0800 Subject: [PATCH 03/40] Fix nccl version in manylinux --- tools/manylinux1/Dockerfile.x64 | 12 ++++-------- .../manylinux1/build_scripts/install_nccl2.sh | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 tools/manylinux1/build_scripts/install_nccl2.sh diff --git 
a/tools/manylinux1/Dockerfile.x64 b/tools/manylinux1/Dockerfile.x64 index 93cab692e3..bca0b77ad7 100644 --- a/tools/manylinux1/Dockerfile.x64 +++ b/tools/manylinux1/Dockerfile.x64 @@ -13,8 +13,10 @@ ENV PATH /opt/rh/devtoolset-2/root/usr/bin:$PATH ENV LD_LIBRARY_PATH /opt/rh/devtoolset-2/root/usr/lib64:/opt/rh/devtoolset-2/root/usr/lib:/usr/local/lib64:/usr/local/lib:${LD_LIBRARY_PATH} ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig +RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool xz COPY build_scripts /build_scripts -RUN bash build_scripts/build.sh && rm -r build_scripts +RUN bash build_scripts/build.sh && \ + bash build_scripts/install_nccl2.sh && rm -r build_scripts ENV SSL_CERT_FILE=/opt/_internal/certs.pem @@ -34,9 +36,6 @@ RUN cd /opt && wget -q --no-check-certificate https://github.com/google/protobuf tar xzf protobuf-cpp-3.1.0.tar.gz && \ cd protobuf-3.1.0 && ./configure && make -j4 && make install && cd .. && rm -f protobuf-cpp-3.1.0.tar.gz - -RUN yum install -y sqlite-devel zlib-devel openssl-devel pcre-devel vim tk-devel tkinter libtool - RUN wget -O /root/requirements.txt https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/python/requirements.txt RUN LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH} /opt/python/cp27-cp27mu/bin/pip install -r /root/requirements.txt && \ @@ -47,10 +46,7 @@ RUN LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH} /o RUN LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH} /opt/python/cp27-cp27mu/bin/pip install pre-commit 'ipython==5.3.0' opencv-python && \ LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH} /opt/python/cp27-cp27m/bin/pip install pre-commit 'ipython==5.3.0' opencv-python -RUN wget -O /opt/swig-2.0.12.tar.gz https://sourceforge.net/projects/swig/files/swig/swig-2.0.12/swig-2.0.12.tar.gz/download && \ +RUN wget -O /opt/swig-2.0.12.tar.gz https://cytranet.dl.sourceforge.net/project/swig/swig/swig-2.0.12/swig-2.0.12.tar.gz && \ cd /opt && tar xzf swig-2.0.12.tar.gz && cd /opt/swig-2.0.12 && ./configure && make && make install && cd /opt && rm swig-2.0.12.tar.gz -RUN mkdir -p /src && cd /src && git clone https://github.com/NVIDIA/nccl.git nccl && cd nccl &&\ - make -j `nproc` install <NCCL_MAKE_OPTS> && cd .. 
&& rm -rf nccl - CMD ["bash", "/paddle/paddle/scripts/docker/build.sh"] diff --git a/tools/manylinux1/build_scripts/install_nccl2.sh b/tools/manylinux1/build_scripts/install_nccl2.sh new file mode 100644 index 0000000000..7efc1fe865 --- /dev/null +++ b/tools/manylinux1/build_scripts/install_nccl2.sh @@ -0,0 +1,18 @@ +#!/bin/bash +DEB="nccl-repo-ubuntu1604-2.1.4-ga-cuda8.0_1-1_amd64.deb" +DIR="/nccl2" +mkdir -p $DIR +# we cached the nccl2 deb package in BOS, so we can download it with wget +# install nccl2: http://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html#down +wget -O $DIR/$DEB \ + "http://nccl2-deb.gz.bcebos.com/nccl-repo-ubuntu1604-2.1.4-ga-cuda8.0_1-1_amd64.deb?responseContentDisposition=attachment" + +cd $DIR && ar x $DEB && tar xf data.tar.xz +DEBS=$(find ./var/ -name "*.deb") +for sub_deb in $DEBS; do + echo $sub_deb + ar x $sub_deb && tar xf data.tar.xz +done +mv -f usr/include/nccl.h /usr/local/include/ +mv -f usr/lib/libnccl* /usr/local/lib/ +rm -rf $DIR From 00e596edbeeb1d5a7f1c4f2608e161a814e59a14 Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Fri, 2 Mar 2018 11:15:32 +0800 Subject: [PATCH 04/40] get max threads of GPU --- paddle/fluid/operators/concat_op.h | 19 +-- paddle/fluid/operators/math/concat.cc | 75 ++++++--- paddle/fluid/operators/math/concat.cu | 170 ++++++++++++++++++--- paddle/fluid/operators/math/concat.h | 11 +- paddle/fluid/operators/math/concat_test.cc | 74 +++++++++ paddle/fluid/platform/gpu_info.cc | 20 +++ paddle/fluid/platform/gpu_info.h | 6 + 7 files changed, 320 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index 19d877dfb6..a65b1987cb 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> { int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); auto place = ctx.GetPlace(); out->mutable_data<T>(place); + std::vector<framework::Tensor> inputs(ins.size()); for (size_t j = 0; j < ins.size(); ++j) { inputs[j] = *ins[j]; @@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> { auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); - size_t input_offset = 0; - auto in_stride = framework::stride_numel(in->dims()); - for (auto& out : outs) { - out->mutable_data<T>(ctx.GetPlace()); - auto out_stride = framework::stride_numel(out->dims()); - StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(), - out_stride, in->data<T>() + input_offset, - in_stride, out_stride[axis]); - input_offset += out_stride[axis]; + std::vector<framework::Tensor> outputs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data<T>(ctx.GetPlace()); + outputs[j] = *outs[j]; } + + auto& dev_ctx = ctx.template device_context<DeviceContext>(); + paddle::operators::math::ConcatGradFunctor<DeviceContext, T> + concat_grad_functor; + concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs); } }; diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index 32059aa2f0..5c5c6489d6 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -25,16 +25,12 @@ template <typename T> class ConcatFunctor<platform::CPUDeviceContext, T> { public: void operator()(const platform::CPUDeviceContext& 
context, - std::vector<framework::Tensor>& input, const int axis, + const std::vector<framework::Tensor>& input, const int axis, framework::Tensor* output) { // assume the the max size of input is less than 8 and see the performance // save origin dim int num = input.size(); std::vector<paddle::framework::DDim> origin_dim(num); - // for (int j = 0; j < num; ++j) { - // origin_dim[j] = input[j].dims(); - // } - auto out_dim = output->dims(); // get the matrix size int rows = 1; @@ -42,40 +38,72 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { for (int i = 0; i < axis; ++i) { rows *= dim_0[i]; } - int cols = input[0].numel() / rows; int out_rows = rows, out_cols = 0; - bool sameShape = true; - // reshape to matrix + // get input's cols + std::vector<int64_t> input_cols(input.size()); for (int i = 0; i < num; ++i) { int t_cols = input[i].numel() / rows; - if (sameShape) { - if (t_cols != cols) sameShape = false; - } out_cols += t_cols; - input[i].Resize({rows, t_cols}); + input_cols[i] = t_cols; } - output->Resize({out_rows, out_cols}); auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); + // computation - for (int k = 0; k < rows; ++k) { - // offset k * out_cols + for (int k = 0; k < out_rows; ++k) { T* dst_ptr = output->data<T>() + k * out_cols; int col_idx = 0; for (int j = 0; j < num; ++j) { - int col_len = input[j].dims()[1]; + int col_len = input_cols[j]; const T* src_prt = input[j].data<T>() + k * col_len; memory::Copy(cpu_place, dst_ptr + col_idx, cpu_place, src_prt, sizeof(T) * col_len); col_idx += col_len; } } + } +}; + +template <typename T> +class ConcatGradFunctor<platform::CPUDeviceContext, T> { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector<framework::Tensor>& outputs) { + // assume the the max size of input is less than 8 and see the performance + // save origin dim + int num = outputs.size(); + std::vector<paddle::framework::DDim> origin_dim(num); - // recover origin dim - // for (int j = 0; j < num; ++j) { - // input[j]->Resize(origin_dim[j]); - // } - output->Resize(out_dim); + // get the matrix size + int input_rows = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_rows *= dim_0[i]; + } + int input_cols = 0; + + // get outputs' cols + std::vector<int64_t> output_cols(outputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = outputs[i].numel() / input_rows; + input_cols += t_cols; + output_cols[i] = t_cols; + } + auto& cpu_place = boost::get<platform::CPUPlace>(context.GetPlace()); + + // computation + for (int k = 0; k < input_rows; ++k) { + const T* src_ptr = input.data<T>() + k * input_cols; + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = output_cols[j]; + T* dst_ptr = outputs[j].data<T>() + k * col_len; + memory::Copy(cpu_place, dst_ptr, cpu_place, src_ptr + col_idx, + sizeof(T) * col_len); + col_idx += col_len; + } + } } }; @@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>; template class ConcatFunctor<platform::CPUDeviceContext, float>; template class ConcatFunctor<platform::CPUDeviceContext, double>; +template class ConcatGradFunctor<platform::CPUDeviceContext, int>; +template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>; +template class ConcatGradFunctor<platform::CPUDeviceContext, float>; +template class ConcatGradFunctor<platform::CPUDeviceContext, double>; + } // namespace math } // namespace operators } // namespace paddle diff --git 
a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 6932e22f84..8af7233426 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -22,7 +22,7 @@ namespace math { // TODO(zcd): This can be replaced by tensor, // if that, maybe we should add int8 to VarType::Type. // Or replaced by tensorArray. -static constexpr int MaxSize = 32; +static constexpr int MaxSize = 8; template <typename T> struct CUDADeviceArray { T data[MaxSize]; @@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int tid_y = blockIdx.y * blockDim.y + threadIdx.y; int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1; int curr_offset = input_cols.data[segment]; @@ -69,13 +68,73 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; const T* input_ptr = inputs.data[curr_segment]; - + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) output[tid_y * output_cols + tid_x] = input_ptr[tid_y * segment_width + local_col]; } } +template <typename T> +__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, + const int input_col, const int output_rows, + const int output_cols, T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + float inv_input_col = 1.0 / input_col; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + const T* input_ptr = inputs.data[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * input_col + in_offset]; + } +} + +template <typename T> +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, + CUDADeviceArray<int> output_cols, + CUDADeviceArray<T*> outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1; + int curr_offset = output_cols.data[segment]; + int curr_segment = segment; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + T curr_col_offset; + while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + T* output_ptr = outputs.data[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * segment_width + local_col] = + input[tid_y * input_col + tid_x]; + } +} + +template <typename T> +__global__ void KernelConcatGrad(const T* input, const int input_row, + const int input_col, const int output_cols, + CUDADeviceArray<T*> outputs) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + float inv_input_col = 1.0 / input_col; + for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * inv_input_col; + int in_offset = tid_x - split * input_col; + T* output_ptr = outputs.data[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) + output_ptr[tid_y * output_cols + in_offset] = + 
input[tid_y * input_col + tid_x]; + } +} + /* * All tensors' dimension should be the same. */ @@ -83,17 +142,13 @@ template <typename T> class ConcatFunctor<platform::CUDADeviceContext, T> { public: void operator()(const platform::CUDADeviceContext& context, - std::vector<framework::Tensor>& input, const int axis, + const std::vector<framework::Tensor>& input, const int axis, framework::Tensor* output) { // assume the the max size of input is less than 8 and see the performance // save origin dim int num = input.size(); - // std::vector<paddle::framework::DDim> origin_dim(num); - // for (int j = 0; j < num; ++j) { - // origin_dim[j] = input[j].dims(); - // } - auto out_dim = output->dims(); - + PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d", + MaxSize); // get the matrix size int rows = 1; auto dim_0 = input[0].dims(); @@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { if (t_cols != cols) sameShape = false; } out_cols += t_cols; - input[i].Resize({rows, t_cols}); inputs_cols.data[i + 1] = out_cols; inputs_data.data[i] = input[i].data<T>(); } - output->Resize({out_rows, out_cols}); // computation - const int kThreadsPerBlock = 256; + // set the thread block and grid according to CurrentDeviceId + const int kThreadsPerBlock = 1024; int block_cols = std::min(out_cols, kThreadsPerBlock); int block_rows = std::max(kThreadsPerBlock / block_cols, 1); dim3 block_size = dim3(block_cols, block_rows, 1); - int grid_cols = (out_cols + block_cols - 1) / block_cols; - int grid_rows = (out_rows + block_rows - 1) / block_rows; + int dev_id = paddle::platform::GetCurrentDeviceId(); + int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id); + int max_threads_per_mp = + paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id); + int max_threads = multi_process * max_threads_per_mp; + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((out_cols + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1)); dim3 grid_size = dim3(grid_cols, grid_rows, 1); - KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( - inputs_data, inputs_cols, out_rows, out_cols, output->data<T>()); + if (sameShape) { + KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( + inputs_data, cols, out_rows, out_cols, output->data<T>()); + } else { + KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( + inputs_data, inputs_cols, out_rows, out_cols, output->data<T>()); + } + } +}; + +template <typename T> +class ConcatGradFunctor<platform::CUDADeviceContext, T> { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, const int axis, + std::vector<framework::Tensor>& outputs) { + // assume the the max size of input is less than 8 and see the performance + // save origin dim + int num = outputs.size(); + PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d", + MaxSize); + + // get the matrix size + int input_row = 1; + auto dim_0 = outputs[0].dims(); + for (int i = 0; i < axis; ++i) { + input_row *= dim_0[i]; + } + + int output_col_0 = outputs[0].numel() / input_row; + int input_col = 0; + bool sameShape = true; + + CUDADeviceArray<T*> outputs_data; + CUDADeviceArray<int> outputs_cols; + outputs_data.size = num; + outputs_cols.size = num + 1; + outputs_cols.data[0] = 0; - // recover origin dim - // for (int j = 0; j < num; ++j) { - // input[j].Resize(origin_dim[j]); - // } - 
output->Resize(out_dim); + for (int i = 0; i < num; ++i) { + int t_col = outputs[i].numel() / input_row; + if (sameShape) { + if (t_col != output_col_0) sameShape = false; + } + input_col += t_col; + outputs_cols.data[i + 1] = input_col; + outputs_data.data[i] = outputs[i].data<T>(); + } + + // computation + const int kThreadsPerBlock = 256; + int block_cols = std::min(input_col, kThreadsPerBlock); + int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + dim3 block_size = dim3(block_cols, block_rows, 1); + + int grid_cols = (input_col + block_cols - 1) / block_cols; + int grid_rows = (input_row + block_rows - 1) / block_rows; + dim3 grid_size = dim3(grid_cols, grid_rows, 1); + + if (sameShape) { + KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( + input.data<T>(), input_row, input_col, output_col_0, outputs_data); + } else { + KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( + input.data<T>(), input_row, input_col, outputs_cols, outputs_data); + } } }; @@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>; template class ConcatFunctor<platform::CUDADeviceContext, float>; template class ConcatFunctor<platform::CUDADeviceContext, double>; +template class ConcatGradFunctor<platform::CUDADeviceContext, int>; +template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>; +template class ConcatGradFunctor<platform::CUDADeviceContext, float>; +template class ConcatGradFunctor<platform::CUDADeviceContext, double>; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h index 50c75dd208..bc87831888 100644 --- a/paddle/fluid/operators/math/concat.h +++ b/paddle/fluid/operators/math/concat.h @@ -20,18 +20,23 @@ namespace operators { namespace math { /* - * the tensor's shape of input will be changed, - * so the second parameter is not const. 
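+ * Concatenates the inputs along `axis`; the inputs are now taken by
+ * const reference and their shapes are left untouched.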
* + */ template <typename DeviceContext, typename T> class ConcatFunctor { public: void operator()(const DeviceContext& context, - std::vector<framework::Tensor>& input, const int axis, + const std::vector<framework::Tensor>& input, const int axis, framework::Tensor* output); }; +template <typename DeviceContext, typename T> +class ConcatGradFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& input, + const int axis, std::vector<framework::Tensor>& outputs); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/concat_test.cc b/paddle/fluid/operators/math/concat_test.cc index 815070b113..1741af8148 100644 --- a/paddle/fluid/operators/math/concat_test.cc +++ b/paddle/fluid/operators/math/concat_test.cc @@ -251,6 +251,80 @@ void testConcat() { } } } + + /** + * case4: + * inputs: + * axis = 1 + * t_a.shape: [2, 3, 4] + * t_b.shape: [2, 3, 4] + * output: + * out.shape: [2, 6, 4] + */ + dim_a = make_ddim({2, 3, 4}); + dim_b = make_ddim({2, 3, 4}); + dim_out = make_ddim({2, 6, 4}); + + input_a.Resize(dim_a); + input_b.Resize(dim_b); + out.Resize(dim_out); + if (paddle::platform::is_gpu_place(Place())) { + input_a_cpu.Resize(dim_a); + input_b_cpu.Resize(dim_b); + out_cpu.Resize(dim_out); + } + + if (paddle::platform::is_gpu_place(Place())) { + a_ptr = input_a_cpu.data<int>(); + b_ptr = input_b_cpu.data<int>(); + } else { + a_ptr = input_a.data<int>(); + b_ptr = input_b.data<int>(); + } + + for (int i = 0; i < 2 * 3 * 4; ++i) { + a_ptr[i] = i; + } + for (int i = 0; i < 2 * 3 * 4; ++i) { + b_ptr[i] = i; + } + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(input_a_cpu, Place(), *context, &input_a); + TensorCopy(input_b_cpu, Place(), *context, &input_b); + } + + input.clear(); + input.push_back(input_a); + input.push_back(input_b); + + concat_functor(*context, input, 1, &out); + + // check the dim of input_a, input_b + PADDLE_ENFORCE_EQ(input_a.dims(), dim_a); + PADDLE_ENFORCE_EQ(input_b.dims(), dim_b); + + if (paddle::platform::is_gpu_place(Place())) { + TensorCopy(out, CPUPlace(), *context, &out_cpu); + out_ptr = out_cpu.data<int>(); + } else { + out_ptr = out.data<int>(); + } + + // check the data + cols = 12; + idx_a = 0, idx_b = 0; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 24; ++j) { + if (j >= cols) { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], b_ptr[idx_b]); + ++idx_b; + } else { + PADDLE_ENFORCE_EQ(out_ptr[i * 24 + j], a_ptr[idx_a]); + ++idx_a; + } + } + } } TEST(math, concat) {
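The launch shape introduced above is computed from the device's occupancy limits rather than from the data size alone. A standalone sketch of the same rule, with a worked example (the helper name and the device figures — 28 SMs at 2048 threads each — are hypothetical):

#include <algorithm>
#include <cstdio>

// Mirrors the grid-sizing logic added to concat.cu: cap the number of
// launched blocks at what the device can keep resident at one time.
void GridSizeSketch(int out_rows, int out_cols, int multi_process,
                    int max_threads_per_mp) {
  const int kThreadsPerBlock = 1024;
  int block_cols = std::min(out_cols, kThreadsPerBlock);
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  int max_threads = multi_process * max_threads_per_mp;
  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

  int grid_cols =
      std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
  int grid_rows =
      std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));

  // GridSizeSketch(10000, 300, 28, 2048) prints block=(300,3) grid=(1,56):
  // 56 blocks saturate the device and each thread strides over extra rows.
  std::printf("block=(%d,%d) grid=(%d,%d)\n", block_cols, block_rows,
              grid_cols, grid_rows);
}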
diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index 05e1eae853..da4041bad0 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -33,6 +33,26 @@ int GetCUDADeviceCount() { return count; } +int GetCUDAMultiProcessors(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count"); + int count; + PADDLE_ENFORCE( + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMultiProcessors"); + return count; +} + +int GetCUDAMaxThreadsPerMultiProcessor(int id) { + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must be less than GPU count"); + int count; + PADDLE_ENFORCE(cudaDeviceGetAttribute( + &count, cudaDevAttrMaxThreadsPerMultiProcessor, id), + "cudaDeviceGetAttribute failed in " + "paddle::platform::GetCUDAMaxThreadsPerMultiProcessor"); + return count; +} + int GetCurrentDeviceId() { int device_id; PADDLE_ENFORCE( diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 3d4883d807..c38ccf0f2a 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse = //! Get the total number of GPU devices in system. int GetCUDADeviceCount(); +//! Get the MultiProcessors of the ith GPU. +int GetCUDAMultiProcessors(int i); + +//! Get the MaxThreads of each MultiProcessor of the ith GPU. +int GetCUDAMaxThreadsPerMultiProcessor(int i); + //! Get the current GPU device id in system. int GetCurrentDeviceId(); From 82bd82c186d0bb228f0f8add3f8089cd44f99b2c Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Mon, 5 Mar 2018 10:23:59 +0800 Subject: [PATCH 05/40] follow comments and refine code --- paddle/fluid/operators/concat_op.h | 2 + paddle/fluid/operators/math/concat.cc | 19 ++-- paddle/fluid/operators/math/concat.cu | 121 ++++++++++++-------------- paddle/fluid/operators/math/concat.h | 21 +++++ 4 files changed, 88 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index a65b1987cb..6ac70eacaf 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -33,6 +33,7 @@ class ConcatKernel : public framework::OpKernel<T> { auto place = ctx.GetPlace(); out->mutable_data<T>(place); + // TODO(zcd): Sometimes direct copies will be faster std::vector<framework::Tensor> inputs(ins.size()); for (size_t j = 0; j < ins.size(); ++j) { inputs[j] = *ins[j]; @@ -51,6 +52,7 @@ class ConcatGradKernel : public framework::OpKernel<T> { auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); + // TODO(zcd): Sometimes direct copies will be faster std::vector<framework::Tensor> outputs(outs.size()); for (size_t j = 0; j < outs.size(); ++j) { outs[j]->mutable_data<T>(ctx.GetPlace()); diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc index 5c5c6489d6..b542143419 100644 --- a/paddle/fluid/operators/math/concat.cc +++ b/paddle/fluid/operators/math/concat.cc @@ -19,7 +19,8 @@ namespace operators { namespace math { /* - * All tensors' dimension should be the same. + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. */ template <typename T> class ConcatFunctor<platform::CPUDeviceContext, T> { @@ -27,12 +28,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { void operator()(const platform::CPUDeviceContext& context, const std::vector<framework::Tensor>& input, const int axis, framework::Tensor* output) { - // assume the the max size of input is less than 8 and see the performance - // save origin dim + // TODO(zcd): Add input data validity checking int num = input.size(); - std::vector<paddle::framework::DDim> origin_dim(num); - // get the matrix size int rows = 1; auto dim_0 = input[0].dims(); for (int i = 0; i < axis; ++i) { @@ -42,7 +38,6 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { } int out_rows = rows, out_cols = 0; - // get input's cols std::vector<int64_t> input_cols(input.size()); for (int i = 0; i < num; ++i) { int t_cols = input[i].numel() / rows; @@ -64,18 +61,19 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { } }; +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension.
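+ * The gradient functor below reverses the copy: each row of the
+ * concatenated input is split into column blocks that are written back
+ * to the corresponding outputs.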
+ */ template <typename T> class ConcatGradFunctor<platform::CPUDeviceContext, T> { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, const int axis, std::vector<framework::Tensor>& outputs) { - // assume the the max size of input is less than 8 and see the performance - // save origin dim + // TODO(zcd): Add input data validity checking int num = outputs.size(); - std::vector<paddle::framework::DDim> origin_dim(num); - // get the matrix size int input_rows = 1; auto dim_0 = outputs[0].dims(); for (int i = 0; i < axis; ++i) { @@ -83,7 +81,6 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { } int input_cols = 0; - // get outputs' cols std::vector<int64_t> output_cols(outputs.size()); for (int i = 0; i < num; ++i) { int t_cols = outputs[i].numel() / input_rows; diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 8af7233426..5f64856a1a 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/platform/cuda_helper.h" @@ -19,16 +20,6 @@ namespace paddle { namespace operators { namespace math { -// TODO(zcd): This can be replaced by tensor, -// if that, maybe we should add int8 to VarType::Type. -// Or replaced by tensorArray. -static constexpr int MaxSize = 8; -template <typename T> -struct CUDADeviceArray { - T data[MaxSize]; - int size; -}; - template <typename T> __device__ T upper_bound(const T* first, T count, T val) { const T* orig = first; @@ -49,25 +40,24 @@ __device__ T upper_bound(const T* first, T count, T val) { } template <typename T> -__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, - const CUDADeviceArray<int> input_cols, +__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound<int>(input_cols.data, input_cols.size, tid_x) - 1; + int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1; - int curr_offset = input_cols.data[segment]; + int curr_offset = input_cols[segment]; int curr_segment = segment; for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { T curr_col_offset; - while ((curr_col_offset = input_cols.data[curr_segment + 1]) <= tid_x) { + while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) { curr_offset = curr_col_offset; ++curr_segment; } int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; - const T* input_ptr = inputs.data[curr_segment]; + T* input_ptr = inputs[curr_segment]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) output[tid_y * output_cols + tid_x] = @@ -76,41 +66,41 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, } template <typename T> -__global__ void KernelConcat(const CUDADeviceArray<const T*> inputs, - const int input_col, const int output_rows, - const int output_cols, T* output) { +__global__ void KernelConcat(T** inputs, const int input_col, + const int output_rows, const int output_cols, + T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; float inv_input_col = 
1.0 / input_col; for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { int split = tid_x * inv_input_col; int in_offset = tid_x - split * input_col; - const T* input_ptr = inputs.data[split]; + T* input_ptr = inputs[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; - for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) { output[tid_y * output_cols + tid_x] = input_ptr[tid_y * input_col + in_offset]; + } } } template <typename T> __global__ void KernelConcatGrad(const T* input, const int input_row, - const int input_col, - CUDADeviceArray<int> output_cols, - CUDADeviceArray<T*> outputs) { + const int input_col, const int* output_cols, + int col_size, T** outputs) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - int segment = upper_bound<int>(output_cols.data, output_cols.size, tid_x) - 1; - int curr_offset = output_cols.data[segment]; + int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1; + int curr_offset = output_cols[segment]; int curr_segment = segment; for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { T curr_col_offset; - while ((curr_col_offset = output_cols.data[curr_segment + 1]) <= tid_x) { + while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) { curr_offset = curr_col_offset; ++curr_segment; } int local_col = tid_x - curr_offset; int segment_width = curr_col_offset - curr_offset; - T* output_ptr = outputs.data[curr_segment]; + T* output_ptr = outputs[curr_segment]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) output_ptr[tid_y * segment_width + local_col] = @@ -121,13 +111,13 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, template <typename T> __global__ void KernelConcatGrad(const T* input, const int input_row, const int input_col, const int output_cols, - CUDADeviceArray<T*> outputs) { + T** outputs) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; float inv_input_col = 1.0 / input_col; for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { int split = tid_x * inv_input_col; int in_offset = tid_x - split * input_col; - T* output_ptr = outputs.data[split]; + T* output_ptr = outputs[split]; int tid_y = blockIdx.y * blockDim.y + threadIdx.y; for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y) output_ptr[tid_y * output_cols + in_offset] = @@ -136,7 +126,8 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, } /* - * All tensors' dimension should be the same. + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. 
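+ *
+ * Implementation note: the input pointers and accumulated column
+ * offsets travel to the GPU through framework::Vector (replacing the
+ * old fixed-size CUDADeviceArray, which limited the number of inputs
+ * to MaxSize = 8); in the kernel each thread locates its source
+ * tensor via upper_bound over the offset array.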
*/ template <typename T> class ConcatFunctor<platform::CUDADeviceContext, T> { @@ -144,12 +135,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { void operator()(const platform::CUDADeviceContext& context, const std::vector<framework::Tensor>& input, const int axis, framework::Tensor* output) { - // assume the the max size of input is less than 8 and see the performance - // save origin dim + // TODO(zcd): Add input data validity checking int num = input.size(); - PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d", - MaxSize); - // get the matrix size int rows = 1; auto dim_0 = input[0].dims(); for (int i = 0; i < axis; ++i) { @@ -157,25 +144,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { } int cols = input[0].numel() / rows; int out_rows = rows, out_cols = 0; - bool sameShape = true; - CUDADeviceArray<const T*> inputs_data; - CUDADeviceArray<int> inputs_cols; - inputs_data.size = num; - inputs_cols.size = num + 1; - inputs_cols.data[0] = 0; - // reshape to matrix - // check input shape is valid + paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2); + paddle::framework::Vector<int> inputs_cols(num + 1); + inputs_cols[0] = 0; + T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data()); + + bool sameShape = true; for (int i = 0; i < num; ++i) { int t_cols = input[i].numel() / rows; if (sameShape) { if (t_cols != cols) sameShape = false; } out_cols += t_cols; - inputs_cols.data[i + 1] = out_cols; - inputs_data.data[i] = input[i].data<T>(); + inputs_cols[i + 1] = out_cols; + inputs_ptr[i] = const_cast<T*>(input[i].data<T>()); } + T** ins_gpu = + reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace())); + const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace()); + // computation // set the thread block and grid according to CurrentDeviceId const int kThreadsPerBlock = 1024; @@ -198,27 +187,27 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { if (sameShape) { KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( - inputs_data, cols, out_rows, out_cols, output->data<T>()); + ins_gpu, cols, out_rows, out_cols, output->data<T>()); } else { KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( - inputs_data, inputs_cols, out_rows, out_cols, output->data<T>()); + ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows, + out_cols, output->data<T>()); } } }; +/* + * All tensors' dimension should be the same and the values of + * each dimension are the same, except the axis dimension. 
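+ *
+ * KernelConcatGrad threads stride over the input columns; each element
+ * is routed to the output that owns its column range (found with
+ * upper_bound when the outputs' widths differ).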
+ */ template <typename T> class ConcatGradFunctor<platform::CUDADeviceContext, T> { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& input, const int axis, std::vector<framework::Tensor>& outputs) { - // assume the the max size of input is less than 8 and see the performance - // save origin dim + // TODO(zcd): Add input data validity checking int num = outputs.size(); - PADDLE_ENFORCE_LT(num, MaxSize, "input number should be less than %d", - MaxSize); - - // get the matrix size int input_row = 1; auto dim_0 = outputs[0].dims(); for (int i = 0; i < axis; ++i) { @@ -229,11 +218,10 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { int input_col = 0; bool sameShape = true; - CUDADeviceArray<T*> outputs_data; - CUDADeviceArray<int> outputs_cols; - outputs_data.size = num; - outputs_cols.size = num + 1; - outputs_cols.data[0] = 0; + paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2); + paddle::framework::Vector<int> outputs_cols(num + 1); + outputs_cols[0] = 0; + T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data()); for (int i = 0; i < num; ++i) { int t_col = outputs[i].numel() / input_row; @@ -241,12 +229,16 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { if (t_col != output_col_0) sameShape = false; } input_col += t_col; - outputs_cols.data[i + 1] = input_col; - outputs_data.data[i] = outputs[i].data<T>(); + outputs_cols[i + 1] = input_col; + outputs_ptr[i] = outputs[i].data<T>(); } + T** outs_gpu = + reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace())); + const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace()); + // computation - const int kThreadsPerBlock = 256; + const int kThreadsPerBlock = 1024; int block_cols = std::min(input_col, kThreadsPerBlock); int block_rows = std::max(kThreadsPerBlock / block_cols, 1); dim3 block_size = dim3(block_cols, block_rows, 1); @@ -257,10 +249,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { if (sameShape) { KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( - input.data<T>(), input_row, input_col, output_col_0, outputs_data); + input.data<T>(), input_row, input_col, output_col_0, outs_gpu); } else { KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( - input.data<T>(), input_row, input_col, outputs_cols, outputs_data); + input.data<T>(), input_row, input_col, outs_col_gpu, + static_cast<int>(outputs_cols.size()), outs_gpu); } } }; diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h index bc87831888..22147d79e4 100644 --- a/paddle/fluid/operators/math/concat.h +++ b/paddle/fluid/operators/math/concat.h @@ -20,7 +20,16 @@ namespace operators { namespace math { /* + * \brief Concatenate the input tensors along the dimension axis. + * TODO(zcd): maybe it needs to be more detailed. + * Examples: + * Input[0] = [[1,2],[3,4]] + * Input[1] = [[5,6]] + * axis = 0 * + * Output = [[1,2], + * [3,4], + * [5,6]] */ template <typename DeviceContext, typename T> class ConcatFunctor { @@ -30,6 +39,18 @@ class ConcatFunctor { framework::Tensor* output); }; +/* + * \brief Split the input tensors along the dimension axis into outputs. + * TODO(zcd): maybe it needs to be more detailed. 
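+ * Note: every output tensor must already be resized to its final shape
+ * and have memory allocated; the functor only copies data. A minimal
+ * usage sketch (illustrative; dev_ctx, input and outputs are assumed
+ * to be set up by the caller):
+ *   math::ConcatGradFunctor<platform::CUDADeviceContext, float> f;
+ *   f(dev_ctx, input, axis, outputs);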
+ * Examples: + * Input = [[1,2], + * [3,4], + * [5,6]] + * axis = 0 + * + * Output[0] = [[1,2],[3,4]] + * Output[1] = [[5,6]] + */ template <typename DeviceContext, typename T> class ConcatGradFunctor { public: From 131ec276edbaee8cd571a244b3885e03c9176788 Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Mon, 5 Mar 2018 22:38:57 +0800 Subject: [PATCH 06/40] fix bug for big number; float->double and code refine --- paddle/fluid/operators/math/concat.cu | 41 +++++++++++++++---------- paddle/fluid/platform/device_context.cc | 6 ++++ paddle/fluid/platform/device_context.h | 6 ++++ 3 files changed, 36 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu index 5f64856a1a..60b266f08f 100644 --- a/paddle/fluid/operators/math/concat.cu +++ b/paddle/fluid/operators/math/concat.cu @@ -70,7 +70,7 @@ __global__ void KernelConcat(T** inputs, const int input_col, const int output_rows, const int output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - float inv_input_col = 1.0 / input_col; + double inv_input_col = 1.0 / input_col; for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { int split = tid_x * inv_input_col; int in_offset = tid_x - split * input_col; @@ -113,7 +113,7 @@ __global__ void KernelConcatGrad(const T* input, const int input_row, const int input_col, const int output_cols, T** outputs) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; - float inv_input_col = 1.0 / input_col; + double inv_input_col = 1.0 / input_col; for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) { int split = tid_x * inv_input_col; int in_offset = tid_x - split * input_col; @@ -145,8 +145,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { int cols = input[0].numel() / rows; int out_rows = rows, out_cols = 0; - paddle::framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2); - paddle::framework::Vector<int> inputs_cols(num + 1); + framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2); + framework::Vector<int> inputs_cols(num + 1); inputs_cols[0] = 0; T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data()); @@ -168,15 +168,14 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { // computation // set the thread block and grid according to CurrentDeviceId const int kThreadsPerBlock = 1024; - int block_cols = std::min(out_cols, kThreadsPerBlock); - int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + int block_cols = kThreadsPerBlock; + if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32. 
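+      // ((x + 31) >> 5) << 5 rounds x up to the next multiple of 32,
+      // so the block width always spans a whole number of warps.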
+ block_cols = ((out_cols + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; dim3 block_size = dim3(block_cols, block_rows, 1); - int dev_id = paddle::platform::GetCurrentDeviceId(); - int multi_process = paddle::platform::GetCUDAMultiProcessors(dev_id); - int max_threads_per_mp = - paddle::platform::GetCUDAMaxThreadsPerMultiProcessor(dev_id); - int max_threads = multi_process * max_threads_per_mp; + int max_threads = context.GetMaxPhysicalThreadCount(); int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); int grid_cols = @@ -218,8 +217,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { int input_col = 0; bool sameShape = true; - paddle::framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2); - paddle::framework::Vector<int> outputs_cols(num + 1); + framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2); + framework::Vector<int> outputs_cols(num + 1); outputs_cols[0] = 0; T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data()); @@ -239,12 +238,20 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { // computation const int kThreadsPerBlock = 1024; - int block_cols = std::min(input_col, kThreadsPerBlock); - int block_rows = std::max(kThreadsPerBlock / block_cols, 1); + int block_cols = kThreadsPerBlock; + if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((input_col + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; dim3 block_size = dim3(block_cols, block_rows, 1); - int grid_cols = (input_col + block_cols - 1) / block_cols; - int grid_rows = (input_row + block_rows - 1) / block_rows; + int max_threads = context.GetMaxPhysicalThreadCount(); + int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); + + int grid_cols = + std::min((input_col + block_cols - 1) / block_cols, max_blocks); + int grid_rows = + std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1)); dim3 grid_size = dim3(grid_cols, grid_rows, 1); if (sameShape) { diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 7da6e04d0a..583a3e740e 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -121,6 +121,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { SetDeviceId(place_.device); + multi_process = GetCUDAMultiProcessors(place_.device); + max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); @@ -154,6 +156,10 @@ void CUDADeviceContext::Wait() const { PADDLE_ENFORCE(cudaGetLastError()); } +int CUDADeviceContext::GetMaxPhysicalThreadCount() const { + return multi_process * max_threads_per_mp; +} + Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { return eigen_device_.get(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index a294ba5101..918243ccfe 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return place in the device context. */ Place GetPlace() const override; + /*! \brief Return the max physical thread count in the device context */ + int GetMaxPhysicalThreadCount() const; + /*! \brief Return eigen device in the device context. 
*/
  Eigen::GpuDevice* eigen_device() const;

@@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext {
   cudaStream_t stream_;
   cudnnHandle_t cudnn_handle_;
   cublasHandle_t cublas_handle_;
+
+  int multi_process;
+  int max_threads_per_mp;
 };

 template <>

From c3864eab994ffacfe52c5c4477019268263f473e Mon Sep 17 00:00:00 2001
From: chengduoZH <zhaochengduo@163.com>
Date: Mon, 5 Mar 2018 23:50:36 +0800
Subject: [PATCH 07/40] if axis == 0; directly copy D->D

---
 paddle/fluid/operators/concat_op.h | 60 +++++++++++++++++++++---------
 1 file changed, 43 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 6ac70eacaf..92c8ab6d9f 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -33,14 +33,26 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
-    // TODO(zcd): Sometimes direct copies will be faster
-    std::vector<framework::Tensor> inputs(ins.size());
-    for (size_t j = 0; j < ins.size(); ++j) {
-      inputs[j] = *ins[j];
+    // Sometimes direct copies will be faster; this may need deeper analysis.
+    if (axis == 0 && ins.size() < 10) {
+      size_t output_offset = 0;
+      for (auto* in : ins) {
+        auto in_stride = framework::stride_numel(in->dims());
+        auto out_stride = framework::stride_numel(out->dims());
+        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
+                                    out->data<T>() + output_offset, out_stride,
+                                    in->data<T>(), in_stride, in_stride[axis]);
+        output_offset += in_stride[axis];
+      }
+    } else {
+      std::vector<framework::Tensor> inputs(ins.size());
+      for (size_t j = 0; j < ins.size(); ++j) {
+        inputs[j] = *ins[j];
+      }
+      auto& dev_ctx = ctx.template device_context<DeviceContext>();
+      paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
+      concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
     }
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    paddle::operators::math::ConcatFunctor<DeviceContext, T> concat_functor;
-    concat_functor(dev_ctx, inputs, static_cast<int>(axis), out);
   }
 };

@@ -52,17 +64,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
     auto outs =
         ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-    // TODO(zcd): Sometimes direct copies will be faster
-    std::vector<framework::Tensor> outputs(outs.size());
-    for (size_t j = 0; j < outs.size(); ++j) {
-      outs[j]->mutable_data<T>(ctx.GetPlace());
-      outputs[j] = *outs[j];
-    }
+    // Sometimes direct copies will be faster; this may need deeper analysis.
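+    // When axis == 0, each gradient slice is a contiguous row block of
+    // the input, so one strided D->D copy per output is enough.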
+ if (axis == 0 && outs.size() < 10) { + size_t input_offset = 0; + auto in_stride = framework::stride_numel(in->dims()); + + for (auto& out : outs) { + out->mutable_data<T>(ctx.GetPlace()); + auto out_stride = framework::stride_numel(out->dims()); + StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(), + out_stride, in->data<T>() + input_offset, + in_stride, out_stride[axis]); + input_offset += out_stride[axis]; + } + } else { + std::vector<framework::Tensor> outputs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + outs[j]->mutable_data<T>(ctx.GetPlace()); + outputs[j] = *outs[j]; + } - auto& dev_ctx = ctx.template device_context<DeviceContext>(); - paddle::operators::math::ConcatGradFunctor<DeviceContext, T> - concat_grad_functor; - concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs); + auto& dev_ctx = ctx.template device_context<DeviceContext>(); + paddle::operators::math::ConcatGradFunctor<DeviceContext, T> + concat_grad_functor; + concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), outputs); + } } }; From cf6244c1b819648134ffd08d4fa2e1ac99535427 Mon Sep 17 00:00:00 2001 From: Xin Pan <panxin.grad@gmail.com> Date: Mon, 5 Mar 2018 23:17:39 -0800 Subject: [PATCH 08/40] Improve profiler smaller binary proto avoid untrackable kernel --- paddle/fluid/platform/device_tracer.cc | 17 ++++++++++++----- paddle/fluid/platform/profiler.proto | 2 +- tools/timeline.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 265343573b..6efe703e22 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -193,20 +193,27 @@ class DeviceTracerImpl : public DeviceTracer { void AddCPURecords(const char *anno, uint64_t start_ns, uint64_t end_ns) { std::lock_guard<std::mutex> l(trace_mu_); - cpu_records_.push_back( - CPURecord{anno, start_ns, end_ns, - std::hash<std::thread::id>{}(std::this_thread::get_id())}); + cpu_records_.push_back(CPURecord{anno, start_ns, end_ns, 0}); } void AddMemRecords(const std::string &name, uint64_t start_ns, uint64_t end_ns, uint32_t device_id, uint32_t stream_id, uint32_t correlation_id, uint64_t bytes) { + // 0 means timestamp information could not be collected for the kernel. + if (start_ns == 0 || end_ns == 0) { + return; + } + std::lock_guard<std::mutex> l(trace_mu_); mem_records_.push_back(MemRecord{name, start_ns, end_ns, device_id, stream_id, correlation_id, bytes}); } void AddKernelRecords(uint64_t start, uint64_t end, uint32_t device_id, uint32_t stream_id, uint32_t correlation_id) { + // 0 means timestamp information could not be collected for the kernel. 
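+    // Skip such kernels entirely; recording them would put events with
+    // bogus zero timestamps into the timeline.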
+ if (start == 0 || end == 0) { + return; + } std::lock_guard<std::mutex> l(trace_mu_); kernel_records_.push_back( KernelRecord{start, end, device_id, stream_id, correlation_id}); @@ -279,10 +286,10 @@ class DeviceTracerImpl : public DeviceTracer { event->set_device_id(r.device_id); event->mutable_memcopy()->set_bytes(r.bytes); } - std::string profile_str; - google::protobuf::TextFormat::PrintToString(profile_pb, &profile_str); std::ofstream profile_f; profile_f.open(profile_path, std::ios::out | std::ios::trunc); + std::string profile_str; + profile_pb.SerializeToString(&profile_str); profile_f << profile_str; profile_f.close(); return profile_pb; diff --git a/paddle/fluid/platform/profiler.proto b/paddle/fluid/platform/profiler.proto index 06db7ed638..71b5a9b12e 100644 --- a/paddle/fluid/platform/profiler.proto +++ b/paddle/fluid/platform/profiler.proto @@ -15,7 +15,7 @@ limitations under the License. */ syntax = "proto2"; package paddle.platform.proto; -message MemCopy { optional uint64 bytes = 3; } +message MemCopy { optional uint64 bytes = 1; } message Event { optional string name = 1; diff --git a/tools/timeline.py b/tools/timeline.py index d1d1dae2bd..ee83a1baec 100644 --- a/tools/timeline.py +++ b/tools/timeline.py @@ -159,7 +159,7 @@ if args.timeline_path: with open(profile_path, 'r') as f: profile_s = f.read() profile_pb = profiler_pb2.Profile() - text_format.Merge(profile_s, profile_pb) + profile_pb.ParseFromString(profile_s) tl = Timeline(profile_pb) with open(timeline_path, 'w') as f: From 8b30fadac3c0a9acec72937f330328dbbe1e9305 Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Wed, 7 Mar 2018 10:21:47 +0800 Subject: [PATCH 09/40] refine elementwise sub,div,min,max --- paddle/fluid/operators/elementwise_div_op.h | 79 ++------------------- paddle/fluid/operators/elementwise_max_op.h | 79 +++------------------ paddle/fluid/operators/elementwise_min_op.h | 79 +++------------------ paddle/fluid/operators/elementwise_mul_op.h | 11 ++- paddle/fluid/operators/elementwise_sub_op.h | 63 ++-------------- 5 files changed, 34 insertions(+), 277 deletions(-) diff --git a/paddle/fluid/operators/elementwise_div_op.h b/paddle/fluid/operators/elementwise_div_op.h index 6bcc577456..95649ac46e 100644 --- a/paddle/fluid/operators/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise_div_op.h @@ -41,77 +41,14 @@ class ElementwiseDivKernel : public framework::OpKernel<T> { }; template <typename T> -struct ElementwiseDivGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto z_e = framework::EigenVector<T>::Flatten(*z); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e / y_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = -1.0 * dz_e * z_e / y_e; - } - } -}; - -template <typename T> -struct ElementwiseDivBroadCastGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 
2>(1, n)) - .broadcast(Eigen::DSizes<int, 2>(pre, 1)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e / y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast)) - .reshape(Eigen::DSizes<int, 2>(pre, n)) - .sum(Eigen::array<int, 1>{{0}}); - } - } +struct DivGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } }; template <typename T> -struct ElementwiseDivBroadCast2GradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N, typename Post> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) - .broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e / y_e_bcast; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (-1.0 * (x_e * dz_e) / (y_e_bcast * y_e_bcast)) - .reshape(Eigen::DSizes<int, 3>(pre, n, post)) - .sum(Eigen::array<int, 2>{{0, 2}}); - } +struct DivGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return -dout * x / (y * y); } }; @@ -128,10 +65,8 @@ class ElementwiseDivGradKernel : public framework::OpKernel<T> { auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); int axis = ctx.Attr<int>("axis"); - ElementwiseGradCompute<DeviceContext, T, ElementwiseDivGradFunctor<T>, - ElementwiseDivBroadCastGradFunctor<T>, - ElementwiseDivBroadCast2GradFunctor<T>>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( + ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); } }; diff --git a/paddle/fluid/operators/elementwise_max_op.h b/paddle/fluid/operators/elementwise_max_op.h index ab3a3d5827..527a18ee3b 100644 --- a/paddle/fluid/operators/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise_max_op.h @@ -41,76 +41,16 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> { }; template <typename T> -struct ElementwiseMaxGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (x_e <= y_e).template cast<T>() * dz_e; - } +struct MaxGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x > y); } }; template <typename T> -struct ElementwiseMaxBroadCastGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N 
n) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) - .broadcast(Eigen::DSizes<int, 2>(pre, 1)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e) - .reshape(Eigen::DSizes<int, 2>(pre, n)) - .sum(Eigen::array<int, 1>{{0}}); - } - } -}; - -template <typename T> -struct ElementwiseMaxBroadCast2GradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N, typename Post> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) - .broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e > y_e_bcast).template cast<T>() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = ((x_e <= y_e_bcast).template cast<T>() * dz_e) - .reshape(Eigen::DSizes<int, 3>(pre, n, post)) - .sum(Eigen::array<int, 2>{{0, 2}}); - } +struct MaxGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x <= y); } }; @@ -127,12 +67,9 @@ class ElementwiseMaxGradKernel : public framework::OpKernel<T> { auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); int axis = ctx.Attr<int>("axis"); - ElementwiseGradCompute<DeviceContext, T, ElementwiseMaxGradFunctor<T>, - ElementwiseMaxBroadCastGradFunctor<T>, - ElementwiseMaxBroadCast2GradFunctor<T>>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_min_op.h b/paddle/fluid/operators/elementwise_min_op.h index f0eec9d246..d4e5831463 100644 --- a/paddle/fluid/operators/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise_min_op.h @@ -41,76 +41,16 @@ class ElementwiseMinKernel : public framework::OpKernel<T> { }; template <typename T> -struct ElementwiseMinGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e < y_e).template cast<T>() * dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (x_e >= y_e).template cast<T>() * dz_e; - } +struct MinGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x < y); } }; template <typename T> -struct 
ElementwiseMinBroadCastGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 2>(1, n)) - .broadcast(Eigen::DSizes<int, 2>(pre, 1)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e < y_e_bcast).template cast<T>() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = ((x_e >= y_e_bcast).template cast<T>() * dz_e) - .reshape(Eigen::DSizes<int, 2>(pre, n)) - .sum(Eigen::array<int, 1>{{0}}); - } - } -}; - -template <typename T> -struct ElementwiseMinBroadCast2GradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N, typename Post> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto x_e = framework::EigenVector<T>::Flatten(*x); - auto y_e = framework::EigenVector<T>::Flatten(*y); - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - - auto y_e_bcast = y_e.reshape(Eigen::DSizes<int, 3>(1, n, 1)) - .broadcast(Eigen::DSizes<int, 3>(pre, 1, post)) - .reshape(Eigen::DSizes<int, 1>(x_e.size())); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = (x_e < y_e_bcast).template cast<T>() * dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = ((x_e >= y_e_bcast).template cast<T>() * dz_e) - .reshape(Eigen::DSizes<int, 3>(pre, n, post)) - .sum(Eigen::array<int, 2>{{0, 2}}); - } +struct MinGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * (x >= y); } }; @@ -127,12 +67,9 @@ class ElementwiseMinGradKernel : public framework::OpKernel<T> { auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); int axis = ctx.Attr<int>("axis"); - ElementwiseGradCompute<DeviceContext, T, ElementwiseMinGradFunctor<T>, - ElementwiseMinBroadCastGradFunctor<T>, - ElementwiseMinBroadCast2GradFunctor<T>>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute<DeviceContext, T, MinGradDx<T>, MinGradDy<T>>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>()); } }; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise_mul_op.h b/paddle/fluid/operators/elementwise_mul_op.h index e2b59b3112..dc73cb6f23 100644 --- a/paddle/fluid/operators/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise_mul_op.h @@ -40,14 +40,15 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { }; template <typename T> -struct IdentityGrad_DX { +struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; template <typename T> -struct IdentityGrad_DY { +struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; + template <typename DeviceContext, typename T> class ElementwiseMulGradKernel : public framework::OpKernel<T> { public: @@ -61,10 +62,8 @@ class ElementwiseMulGradKernel : public framework::OpKernel<T> { auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dy = 
ctx.Output<Tensor>(framework::GradVarName("Y")); int axis = ctx.Attr<int>("axis"); - ElemwiseGradCompute<DeviceContext, T, IdentityGrad_DX<T>, - IdentityGrad_DY<T>>(ctx, *x, *y, *out, *dout, axis, dx, - dy, IdentityGrad_DX<T>(), - IdentityGrad_DY<T>()); + ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>()); } }; } // namespace operators diff --git a/paddle/fluid/operators/elementwise_sub_op.h b/paddle/fluid/operators/elementwise_sub_op.h index a8fc242ed7..fe088b8203 100644 --- a/paddle/fluid/operators/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise_sub_op.h @@ -40,61 +40,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> { }; template <typename T> -struct ElementwiseSubGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (-1.0) * dz_e; - } - } +struct SubGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; template <typename T> -struct ElementwiseSubBroadCastGradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (-1.0) * - dz_e.reshape(Eigen::DSizes<int, 2>(pre, n)) - .sum(Eigen::array<int, 1>{{0}}); - } - } -}; - -template <typename T> -struct ElementwiseSubBroadCast2GradFunctor { - template <typename Device, typename X, typename Y, typename Z, typename dX, - typename dY, typename dZ, typename Pre, typename N, typename Post> - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto dz_e = framework::EigenVector<T>::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector<T>::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector<T>::Flatten(*dy); - dy_e.device(d) = (-1.0) * - dz_e.reshape(Eigen::DSizes<int, 3>(pre, n, post)) - .sum(Eigen::array<int, 2>{{0, 2}}); - } - } +struct SubGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; } }; template <typename DeviceContext, typename T> @@ -110,12 +62,9 @@ class ElementwiseSubGradKernel : public framework::OpKernel<T> { auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); int axis = ctx.Attr<int>("axis"); - ElementwiseGradCompute<DeviceContext, T, ElementwiseSubGradFunctor<T>, - ElementwiseSubBroadCastGradFunctor<T>, - ElementwiseSubBroadCast2GradFunctor<T>>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>( + ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>()); } }; - } // namespace operators } // namespace paddle From 049383c615ff6d3ecd9f15e246b8d3c688f05b4d Mon Sep 17 00:00:00 2001 From: Yan Chunwei <yanchunwei@outlook.com> Date: Wed, 7 Mar 2018 12:39:07 +0800 Subject: 
[PATCH 10/40] add inplace to reshape (#8747) --- paddle/fluid/operators/reshape_op.cc | 3 ++ paddle/fluid/operators/reshape_op.h | 22 +++++++++++---- .../fluid/tests/unittests/test_reshape_op.py | 28 +++++++++++++++++++ 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 3580932356..832509641c 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -84,6 +84,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr<std::vector<int>>("shape", "(vector<int>) " "Target shape of reshape operator."); + AddAttr<bool>("inplace", + "Change the source tensor's shape without copy memory.") + .SetDefault(true); AddComment(R"DOC( Reshape Operator. diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 1357bce4b7..eacb0a0cf2 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel<T> { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output<framework::Tensor>("Out"); auto* in = ctx.Input<framework::Tensor>("X"); + bool inplace = ctx.Attr<bool>("inplace"); auto out_dims = out->dims(); - out->mutable_data<T>(ctx.GetPlace()); - framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); - out->Resize(out_dims); + if (!inplace) { + out->mutable_data<T>(ctx.GetPlace()); + framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } } }; @@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel<T> { auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); d_x->mutable_data<T>(ctx.GetPlace()); + bool inplace = ctx.Attr<bool>("inplace"); auto in_dims = d_x->dims(); - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - d_x->Resize(in_dims); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 6d1aa549d5..11f35c74d4 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -45,5 +45,33 @@ class TestReshapeOpDimInfer(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOpInplace(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [10 * 20], 'inplace': True} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInferInplace(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [4, -1, 5], 'inplace': True} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + if __name__ == '__main__': unittest.main() From 
8c71adaa8c9eb8debd802aaf3a7166ef3e3718d3 Mon Sep 17 00:00:00 2001 From: pzelazko-intel <pawel.zelazko@intel.com> Date: Wed, 7 Mar 2018 06:40:54 +0100 Subject: [PATCH 11/40] MKLDNN conv2d kernel added (#8451) * MKLDNN conv2 OP kernel added * TODOs added * mkldnn conv2d OP refactor * CanCUDNNBeUsed and CanMKLDNNBeUsed moved --- paddle/fluid/operators/CMakeLists.txt | 26 +- paddle/fluid/operators/conv_mkldnn_op.cc | 313 ++++++++++++++++++ paddle/fluid/operators/conv_op.cc | 51 +-- paddle/fluid/platform/cudnn_helper.h | 14 + paddle/fluid/platform/device_context.cc | 76 ++--- paddle/fluid/platform/device_context.h | 45 +-- paddle/fluid/platform/mkldnn_helper.h | 15 + python/paddle/fluid/layers/nn.py | 4 +- python/paddle/fluid/nets.py | 12 +- .../fluid/tests/unittests/test_conv2d_op.py | 24 +- 10 files changed, 465 insertions(+), 115 deletions(-) create mode 100644 paddle/fluid/operators/conv_mkldnn_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 62f00ab612..7e803e3974 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -1,5 +1,7 @@ file(GLOB GENERAL_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +string(REPLACE "_mkldnn" "" GENERAL_OPS "${GENERAL_OPS}") string(REPLACE ".cc" "" GENERAL_OPS "${GENERAL_OPS}") +list(REMOVE_DUPLICATES GENERAL_OPS) set(DEPS_OPS "") set(pybind_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/pybind.h) file(WRITE ${pybind_file} "// Generated by the paddle/operator/CMakeLists.txt. DO NOT EDIT!\n\n") @@ -13,6 +15,8 @@ function(op_library TARGET) set(cu_cc_srcs) set(cudnn_cu_cc_srcs) set(CUDNN_FILE) + set(mkldnn_cc_srcs) + set(MKLDNN_FILE) set(op_common_deps operator op_registry math_function) set(options "") set(oneValueArgs "") @@ -36,12 +40,20 @@ function(op_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc) list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc) endif() + if(WITH_MKLDNN) + string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc) + list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc) + endif() + endif() else() foreach(src ${op_library_SRCS}) if (${src} MATCHES ".*\\.cu$") list(APPEND cu_srcs ${src}) elseif(${src} MATCHES ".*_cudnn_op.cu.cc$") list(APPEND cudnn_cu_cc_srcs ${src}) + elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") + list(APPEND mkldnn_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cu.cc$") list(APPEND cu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") @@ -62,11 +74,11 @@ function(op_library TARGET) set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE) endif() if (WITH_GPU) - nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} + nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() - cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS} - ${op_common_deps}) + cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS} + ${op_common_deps}) endif() # Define operators that don't need pybind here. 
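+# Note: a *_mkldnn_op.cc file is built into the same op_library target
+# as its plain counterpart (the "_mkldnn" suffix is stripped above), so
+# e.g. conv_mkldnn_op.cc is compiled into the conv_op library.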
@@ -101,7 +113,8 @@ function(op_library TARGET) # pybind USE_CPU_ONLY_OP list(LENGTH cu_srcs cu_srcs_len) list(LENGTH cu_cc_srcs cu_cc_srcs_len) - if (${pybind_flag} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) + list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) + if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0) file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n") set(pybind_flag 1) endif() @@ -112,6 +125,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n") endif() + # pybind USE_OP_DEVICE_KERNEL for MKLDNN + if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") + endif() + # pybind USE_OP if (${pybind_flag} EQUAL 0) file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc new file mode 100644 index 0000000000..d59cc2c9d4 --- /dev/null +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -0,0 +1,313 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/conv_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::MKLDNNMemDesc; + +using mkldnn::memory; // Note: paddle has also "memory" namespace +using mkldnn::primitive; +using mkldnn::convolution_forward; +using mkldnn::convolution_backward_weights; +using mkldnn::convolution_backward_data; +using mkldnn::convolution_direct; +using mkldnn::prop_kind; +using mkldnn::padding_kind; +using mkldnn::stream; + +namespace { +std::unique_ptr<mkldnn::convolution_forward::primitive_desc> +ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, + const memory::desc& dst, const std::vector<int>& strides, + const std::vector<int>& paddings, + const mkldnn::engine& engine); + +convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc( + const memory::desc& src, const memory::desc& diff_weights, + const memory::desc& diff_dst, const std::vector<int>& strides, + const std::vector<int>& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine); + +convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc( + const memory::desc& diff_src, const memory::desc& weights, + const memory::desc& diff_dst, const std::vector<int>& strides, + const std::vector<int>& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine); +} // anonymous namespace + +template <typename T> +class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + 
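+    // MKLDNN primitives execute on the host, so this kernel is CPU-only.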
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto* input = ctx.Input<Tensor>("Input"); + auto* filter = ctx.Input<Tensor>("Filter"); + auto* output = ctx.Output<Tensor>("Output"); + + // Get an unique name from "argument" name of "Output" variable + // This name will be used as key when saving info into device context + const std::string key = ctx.op().Output("Output"); + const std::string key_conv_pd = key + "@conv_pd"; + + std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); + std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); + std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); + int groups = ctx.Attr<int>("groups"); + + // TODO(pzelazko-intel) add support for group convolution and dilation + PADDLE_ENFORCE(groups == 1, "group convolution is not implemented yet"); + PADDLE_ENFORCE( + dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1, + "dilation in convolution is not implemented yet"); + + const T* input_data = input->data<T>(); + const T* filter_data = filter->data<T>(); + // allocate memory for output + T* output_data = output->mutable_data<T>(ctx.GetPlace()); + + PADDLE_ENFORCE(input->dims().size() == 4, + "Input must be with 4 dimensions, i.e. NCHW"); + PADDLE_ENFORCE(filter->dims().size() == 4, + "Filter must be with 4 dimensions, i.e. OIHW"); + + std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector<int> weights_tz = + paddle::framework::vectorize2int(filter->dims()); + std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); + + // TODO(pzelazko-intel): support more formats + // memory descriptors for convolution src/weight/dst + auto conv_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto conv_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_dst_md = + MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw); + + // create memory primitives + auto conv_src_memory = + memory({conv_src_md, mkldnn_engine}, (void*)input_data); + auto conv_weights_memory = + memory({conv_weights_md, mkldnn_engine}, (void*)filter_data); + auto conv_dst_memory = memory({conv_dst_md, mkldnn_engine}, output_data); + + std::unique_ptr<convolution_forward::primitive_desc> conv_pd = + ConvFwdPrimitiveDesc(conv_src_md, conv_weights_md, conv_dst_md, strides, + paddings, mkldnn_engine); + + // save p_conv_pd into dev_ctx to be referred in backward path + auto p_conv_pd = conv_pd.get(); + std::shared_ptr<void> conv_pd_value = std::move(conv_pd); + dev_ctx.SetBlob(key_conv_pd, conv_pd_value); + + // create convolution op primitive + auto conv_prim = convolution_forward(*p_conv_pd, conv_src_memory, + conv_weights_memory, conv_dst_memory); + + // push op to stream and wait MKLDNN until it's executed + std::vector<primitive> pipeline{conv_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } +}; + +template <typename T> +class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const Tensor* input = 
ctx.Input<Tensor>("Input"); + const Tensor* filter = ctx.Input<Tensor>("Filter"); + const Tensor* output = ctx.Input<Tensor>("Output"); + const Tensor* output_grad = + ctx.Input<Tensor>(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter")); + + if (!input_grad && !filter_grad) return; + + // Get an unique name from "argument" name of "Output" variable + // This name will be used as key when saving info into device context + const std::string key = ctx.op().Input("Output"); + const std::string key_conv_pd = key + "@conv_pd"; + + std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); + std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); + + const T* input_data = input->data<T>(); + const T* filter_data = filter->data<T>(); + const T* output_grad_data = output_grad->data<T>(); + T* input_grad_data = nullptr; + T* filter_grad_data = nullptr; + + // allocate memory for gradient of input/filter + if (input_grad) { + input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); + } + if (filter_grad) { + filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace()); + } + + std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector<int> weights_tz = + paddle::framework::vectorize2int(filter->dims()); + std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); + + // TODO(pzelazko-intel): support more formats + auto conv_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto conv_diff_src_md = + MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw); + auto conv_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_diff_weights_md = + MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw); + auto conv_diff_dst_md = + MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw); + + // create memory + auto conv_diff_dst_memory = + memory({conv_diff_weights_md, mkldnn_engine}, (void*)output_grad_data); + // Retrieve conv_pd from device context + std::shared_ptr<void> conv_pd; + convolution_forward::primitive_desc* p_conv_pd; + + conv_pd = dev_ctx.GetBlob(key_conv_pd); + PADDLE_ENFORCE(conv_pd != nullptr, + "Fail to find conv_pd in device context"); + p_conv_pd = + static_cast<convolution_forward::primitive_desc*>(conv_pd.get()); + + // create backward conv primitive for weights + if (filter_grad) { + // create primitive descriptor + convolution_backward_weights::primitive_desc conv_bwd_weights_pd = + ConvBwdWeightsPrimitiveDesc(conv_src_md, conv_diff_weights_md, + conv_diff_dst_md, strides, paddings, + *p_conv_pd, mkldnn_engine); + + // create memory + auto conv_diff_weights_memory = memory( + {conv_diff_weights_md, mkldnn_engine}, (void*)filter_grad_data); + auto conv_src_memory = + memory({conv_src_md, mkldnn_engine}, (void*)input_data); + + // create backward conv primitive for weights + auto conv_bwd_weights_prim = convolution_backward_weights( + conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory, + conv_diff_weights_memory); + + // push primitive and execute it + std::vector<primitive> pipeline{conv_bwd_weights_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } + + if (input_grad) { + // create primitive descriptor + convolution_backward_data::primitive_desc conv_bwd_data_pd = + ConvBwdDataPrimitiveDesc(conv_diff_src_md, conv_weights_md, + 
conv_diff_dst_md, strides, paddings, + *p_conv_pd, mkldnn_engine); + + // create memory + auto conv_diff_src_memory = + memory({conv_diff_src_md, mkldnn_engine}, (void*)input_grad_data); + auto conv_weights_memory = + memory({conv_weights_md, mkldnn_engine}, (void*)filter_data); + + // create backward conv primitive for data + auto conv_bwd_data_prim = + convolution_backward_data(conv_bwd_data_pd, conv_diff_dst_memory, + conv_weights_memory, conv_diff_src_memory); + + // push primitive and execute it + std::vector<primitive> pipeline{conv_bwd_data_prim}; + stream(stream::kind::eager).submit(pipeline).wait(); + } + } // Compute() +}; + +namespace { +std::unique_ptr<convolution_forward::primitive_desc> ConvFwdPrimitiveDesc( + const memory::desc& src, const memory::desc& weights, + const memory::desc& dst, const std::vector<int>& strides, + const std::vector<int>& paddings, const mkldnn::engine& engine) { + mkldnn::memory::dims stride_dims = {strides[0], strides[1]}; + mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]}; + + auto conv_desc = mkldnn::convolution_forward::desc( + mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, dst, + stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); + + auto p_conv_pd = new convolution_forward::primitive_desc(conv_desc, engine); + + return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( + p_conv_pd); +} + +convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc( + const memory::desc& src, const memory::desc& diff_weights, + const memory::desc& diff_dst, const std::vector<int>& strides, + const std::vector<int>& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine) { + auto conv_bwd_weights_desc = convolution_backward_weights::desc( + convolution_direct, src, diff_weights, diff_dst, strides, paddings, + paddings, padding_kind::zero); + return convolution_backward_weights::primitive_desc(conv_bwd_weights_desc, + engine, conv_pd); +} + +convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc( + const memory::desc& diff_src, const memory::desc& weights, + const memory::desc& diff_dst, const std::vector<int>& strides, + const std::vector<int>& paddings, + const convolution_forward::primitive_desc& conv_pd, + const mkldnn::engine& engine) { + auto conv_bwd_data_desc = convolution_backward_data::desc( + convolution_direct, diff_src, weights, diff_dst, strides, paddings, + paddings, padding_kind::zero); + return convolution_backward_data::primitive_desc(conv_bwd_data_desc, engine, + conv_pd); +} +} // anonymous namespace +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConvOpMkldnnKernel<float>); + +REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConvGradOpMkldnnKernel<float>); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 83b7708bf3..4b02b80d77 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -13,6 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 83b7708bf3..4b02b80d77 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/conv_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -64,22 +70,21 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kCUDNN;
   }
 #endif
-  framework::LibraryType library_;
-  if (use_cudnn) {
-    library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
+  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
@@ -131,6 +136,9 @@ Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
@@ -224,6 +232,9 @@ Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       "use_cudnn",
       "(bool, default false) Only used in cudnn kernel, need install cudnn")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
@@ -284,23 +295,21 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kCUDNN;
   }
 #endif
-
-  framework::LibraryType library_;
-  if (use_cudnn) {
-    library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
   }
+#endif
 
   std::string data_format = ctx.Attr<std::string>("data_format");
+  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
   return framework::OpKernelType(
       framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 48c967de11..1842ecd745 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
@@ -282,5 +284,17 @@ class ScopedPoolingDescriptor {
   DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
 };
 
+inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
+#ifdef PADDLE_WITH_CUDA
+  if (use_cudnn) {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  }
+#endif
+  return use_cudnn;
+}
+
 }  // namespace platform
 }  // namespace paddle
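
[Editorial sketch] Taken together, GetExpectedKernelType and the Can*BeUsed helpers implement a simple priority chain: cuDNN wins when its attribute, place, and handle all allow it; MKLDNN is considered only if nothing has claimed the kernel yet; plain kernels are the default. The standalone distillation below uses its own enum and booleans in place of framework::LibraryType and the helper predicates (the real code additionally guards each branch with PADDLE_WITH_CUDA / PADDLE_WITH_MKLDNN at compile time), so it is an illustration, not Paddle code.

// library_selection_sketch.cc
#include <iostream>

// Stand-in for framework::LibraryType.
enum class LibraryType { kPlain, kCUDNN, kMKLDNN };

// can_use_cudnn stands for CanCUDNNBeUsed(ctx):
//   use_cudnn attr && GPU place && cudnn_handle() != nullptr.
// can_use_mkldnn stands for CanMKLDNNBeUsed(ctx):
//   use_mkldnn attr && CPU place.
LibraryType SelectLibrary(bool can_use_cudnn, bool can_use_mkldnn) {
  LibraryType library = LibraryType::kPlain;
  if (can_use_cudnn) {
    library = LibraryType::kCUDNN;  // cuDNN has first claim
  }
  if (library == LibraryType::kPlain && can_use_mkldnn) {
    library = LibraryType::kMKLDNN;  // MKLDNN only if still unclaimed
  }
  return library;
}

int main() {
  // A CPU place with use_mkldnn=true: cuDNN is unusable, MKLDNN is chosen.
  LibraryType chosen = SelectLibrary(/*can_use_cudnn=*/false,
                                     /*can_use_mkldnn=*/true);
  std::cout << (chosen == LibraryType::kMKLDNN ? "kMKLDNN" : "other") << "\n";
  return 0;
}
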
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 7da6e04d0a..326ff67ab9 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -33,9 +33,15 @@ DeviceContextPool::DeviceContextPool(
   PADDLE_ENFORCE_GT(places.size(), 0);
   for (size_t i = 0; i < places.size(); i++) {
     if (platform::is_cpu_place(places[i])) {
+#ifdef PADDLE_WITH_MKLDNN
+      device_contexts_.emplace(places[i],
+                               new platform::MKLDNNDeviceContext(
+                                   boost::get<platform::CPUPlace>(places[i])));
+#else
       device_contexts_.emplace(places[i],
                                new platform::CPUDeviceContext(
                                    boost::get<platform::CPUPlace>(places[i])));
+#endif
     } else if (platform::is_gpu_place(places[i])) {
 #ifdef PADDLE_WITH_CUDA
       device_contexts_.emplace(places[i],
@@ -170,64 +176,38 @@ cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
 #ifdef PADDLE_WITH_MKLDNN
 MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
-    : CPUDeviceContext(place), ready_(false) {
-  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
-  engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
+    : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
+  p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
 }
 
-template <typename T>
-void MKLDNNDeviceContext::AddElement(const std::string& op_key,
-                                     const T& value) {
-  if (GetElement<T>(op_key)) {
-    return;
-  }
-  GetElementPool<T>().emplace(op_key, std::move(value));
-}
+void MKLDNNDeviceContext::SetBlob(const std::string& name,
+                                  std::shared_ptr<void> data) const {
+  std::unordered_map<std::string, std::shared_ptr<void>>* p;
+  p = p_blobs_.get();
 
-template <typename T>
-const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
-  auto it = GetElementPool<T>().find(op_key);
-  return it == GetElementPool<T>().end() ? nullptr : it->second;
-}
+  auto it = p->find(name);
 
-template <>
-const std::unordered_map<const std::string, const MKLDNNMemoryPtr,
-                         std::hash<std::string>>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
-  return memory_pool_;
-}
+  if (it == p->end()) {
+    (*p)[name] = data;  // create new blob
+  } else {
+    it->second = data;  // set data to existing blob
+  }
 
-template <>
-const std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
-                         std::hash<std::string>>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
-  return primitive_pool_;
+  return;
 }
 
-template <>
-const std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
-                         std::hash<std::string>>&
-MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
-  return primitive_desc_pool_;
-}
+std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
+    const std::string& name) const {
+  std::unordered_map<std::string, std::shared_ptr<void>>* p;
+  p = p_blobs_.get();
 
-void MKLDNNDeviceContext::Execute(bool block) {
-  if (pipeline_.empty()) {
-    return;
-  }
-  ResetStream();
-  stream_->submit(pipeline_).wait(block);
-  ready_ = false;
-  pipeline_.clear();
-}
+  auto it = p->find(name);
 
-void MKLDNNDeviceContext::ResetStream() {
-  if (ready_) {
-    return;
+  if (it != p->end()) {
+    return it->second;
   }
-  // TODO(TJ): change me when mkldnn have specific method to reset this state
-  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
-  ready_ = true;
+
+  return nullptr;
 }
 #endif
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index a294ba5101..01de8c4ab3 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #endif
 
 #ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include <mkldnn.hpp>
 #endif
 
 #include "paddle/fluid/platform/enforce.h"
@@ -114,46 +114,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
  public:
   explicit MKLDNNDeviceContext(CPUPlace place);
 
-  /* \brief  Add new element: memory, primitive or primitive desc */
-  template <typename T>
-  void AddElement(const std::string& op_key, const T& value);
-
-  /* \brief  Get existed element: memory, primitive or primitive desc */
-  template <typename T>
-  const T& GetElement(const std::string& op_key) const;
-
-  /* \brief  Get element pool: memory, primitive or primitive desc pool */
-  template <typename T>
-  const std::unordered_map<const std::string, const T, std::hash<std::string>>&
-  GetElementPool() const;
-
   /* \brief  Get the active engine */
-  const MKLDNNEngine& engine() const { return *engine_; }
-
-  /* \brief  Submit primitive to pipeline */
-  void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
+  const mkldnn::engine& GetEngine() const { return engine_; }
 
-  /*! \brief  Execute all submitted primitives in pipeline */
-  void Execute(bool block = true);
+  // Set data to blob (i.e. name/data pair). Create blob if not existing
+  void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
 
- protected:
-  /*! \brief  Reset the stream to prepare next exectue */
-  void ResetStream();
+  // Find a saved blob. Return nullptr if not found
+  std::shared_ptr<void> GetBlob(const std::string& name) const;
 
  private:
-  std::unordered_map<const std::string, const MKLDNNMemoryPtr,
-                     std::hash<std::string>>
-      memory_pool_;
-  std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
-                     std::hash<std::string>>
-      primitive_pool_;
-  std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
-                     std::hash<std::string>>
-      primitive_desc_pool_;
-  std::vector<MKLDNNPrimitive> pipeline_;
-  MKLDNNStreamPtr stream_;
-  MKLDNNEnginePtr engine_;
-  bool ready_;
+  mkldnn::engine engine_;
+  std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
+      p_blobs_;
 };
 #endif
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 6d71f352c6..90b78142b8 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -16,12 +16,15 @@ limitations under the License. */
 
 #include <mkldnn.hpp>
 
+#include "paddle/fluid/framework/operator.h"
+
 namespace paddle {
 namespace platform {
 
 using MKLDNNStream = mkldnn::stream;
 using MKLDNNEngine = mkldnn::engine;
 using MKLDNNMemory = mkldnn::memory;
+using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
 using MKLDNNPrimitive = mkldnn::primitive;
 using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
 
@@ -31,5 +34,17 @@ typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
 typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
 typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
 
+inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
+                                          mkldnn::memory::data_type data_type,
+                                          mkldnn::memory::format format) {
+  mkldnn::memory::dims tz = dims;
+  return mkldnn::memory::desc({tz}, data_type, format);
+}
+
+inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
+  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
+  return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
+}
+
 }  // namespace platform
 }  // namespace paddle
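
[Editorial sketch] The new SetBlob/GetBlob pair replaces the typed element pools with a deliberately simple, type-erased cache: a name-to-shared_ptr<void> map whose callers cast back to the concrete type they stored. Below is a self-contained mimic of those semantics; the BlobCache class, the ConvPrimitiveDesc struct, and the key string are stand-ins for MKLDNNDeviceContext, convolution_forward::primitive_desc, and the "<output name>@conv_pd" key built in the conv kernels above, so this compiles without Paddle or MKLDNN.

// blob_cache_sketch.cc
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for convolution_forward::primitive_desc.
struct ConvPrimitiveDesc {
  int strides[2];
};

class BlobCache {
 public:
  // Create the blob if missing, otherwise overwrite: the same semantics as
  // MKLDNNDeviceContext::SetBlob in the diff above.
  void SetBlob(const std::string& name, std::shared_ptr<void> data) {
    blobs_[name] = data;
  }
  // Return nullptr when the name was never stored, as GetBlob does.
  std::shared_ptr<void> GetBlob(const std::string& name) const {
    auto it = blobs_.find(name);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

int main() {
  BlobCache cache;
  // Forward pass: publish the primitive desc under "<output name>@conv_pd".
  const std::string key_conv_pd = "conv2d_0.out@conv_pd";
  cache.SetBlob(key_conv_pd,
                std::make_shared<ConvPrimitiveDesc>(ConvPrimitiveDesc{{1, 1}}));

  // Backward pass: retrieve and cast back, as ConvGradOpMkldnnKernel does.
  auto blob = cache.GetBlob(key_conv_pd);
  if (blob == nullptr) {
    std::cerr << "Fail to find conv_pd in device context\n";
    return 1;
  }
  auto* pd = static_cast<ConvPrimitiveDesc*>(blob.get());
  std::cout << "found conv_pd, stride0=" << pd->strides[0] << "\n";
  return 0;
}
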
""" @@ -90,7 +93,8 @@ def img_conv_group(input, padding=conv_padding[i], param_attr=param_attr[i], act=local_conv_act, - use_cudnn=use_cudnn) + use_cudnn=use_cudnn, + use_mkldnn=use_mkldnn) if conv_with_batchnorm[i]: tmp = layers.batch_norm(input=tmp, act=conv_act) diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py index 1321cfd484..a49fecf095 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py @@ -64,6 +64,7 @@ def conv2d_forward_naive(input, filter, group, conv_param): class TestConv2dOp(OpTest): def setUp(self): self.use_cudnn = False + self.use_mkldnn = False self.init_op_type() self.init_group() self.init_dilation() @@ -85,7 +86,8 @@ class TestConv2dOp(OpTest): 'paddings': self.pad, 'groups': self.groups, 'dilations': self.dilations, - 'use_cudnn': self.use_cudnn + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn } self.outputs = {'Output': output} @@ -290,5 +292,25 @@ class TestDepthwiseConv2(TestConv2dOp): # def init_op_type(self): # self.op_type = "conv_cudnn" + +#----------------Conv2dMKLDNN---------------- +class TestMKLDNN(TestConv2dOp): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + +class TestMKLDNNWithPad(TestWithPad): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + +class TestMKLDNNWithStride(TestWithStride): + def init_op_type(self): + self.use_mkldnn = True + self.op_type = "conv2d" + + if __name__ == '__main__': unittest.main() From 142fac18ec41abe570147c44e6b434f807efae88 Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Wed, 7 Mar 2018 16:27:38 +0800 Subject: [PATCH 12/40] add print_log to memory_optimize --- python/paddle/fluid/memory_optimization_transpiler.py | 8 ++++++-- .../book_memory_optimization/test_memopt_fit_a_line.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/memory_optimization_transpiler.py b/python/paddle/fluid/memory_optimization_transpiler.py index 708ca08b17..e82456a99f 100644 --- a/python/paddle/fluid/memory_optimization_transpiler.py +++ b/python/paddle/fluid/memory_optimization_transpiler.py @@ -31,6 +31,8 @@ dtype_to_size = { sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"] +PRINT_LOG = False + class ControlFlowGraph(object): def __init__(self, Program, ops, forward_num, skip_opt): @@ -170,7 +172,7 @@ class ControlFlowGraph(object): block_desc, cache_var, is_forward).dtype() # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] # and dtype_to_size[cache_dtype] - if x_dtype == cache_dtype: + if x_dtype == cache_dtype and PRINT_LOG: print(("Hit Cache !!!! 
cache pool index " "is %d, var name is %s, " "cached var name is %s, " @@ -277,7 +279,9 @@ def _get_cfgs(input_program): return cfgs -def memory_optimize(input_program): +def memory_optimize(input_program, print_log=False): + global PRINT_LOG + PRINT_LOG = print_log cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize() diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py index 7648bb9fe1..c9d2a5ecaa 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py @@ -49,7 +49,7 @@ avg_cost = fluid.layers.mean(x=cost) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) sgd_optimizer.minimize(avg_cost) -fluid.memory_optimize(fluid.default_main_program()) +fluid.memory_optimize(fluid.default_main_program(), print_log=True) BATCH_SIZE = 200 From fe2d590d2102d97f95b1bdebd863b1d9cb34feac Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Wed, 7 Mar 2018 16:33:19 +0800 Subject: [PATCH 13/40] fix bug --- .../fluid/memory_optimization_transpiler.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/memory_optimization_transpiler.py b/python/paddle/fluid/memory_optimization_transpiler.py index e82456a99f..4fa2d03ef5 100644 --- a/python/paddle/fluid/memory_optimization_transpiler.py +++ b/python/paddle/fluid/memory_optimization_transpiler.py @@ -172,13 +172,15 @@ class ControlFlowGraph(object): block_desc, cache_var, is_forward).dtype() # TODO(qijun): actually, we should compare dtype_to_size[x_dtype] # and dtype_to_size[cache_dtype] - if x_dtype == cache_dtype and PRINT_LOG: - print(("Hit Cache !!!! cache pool index " - "is %d, var name is %s, " - "cached var name is %s, " - "var shape is %s ") % - (index, x, cache_var, - str(cache_shape))) + if x_dtype == cache_dtype: + if PRINT_LOG: + print( + ("Hit Cache !!!! 
cache pool index " + "is %d, var name is %s, " + "cached var name is %s, " + "var shape is %s ") % + (index, x, cache_var, + str(cache_shape))) self.pool.pop(index) if x == cache_var: break From 3ddc9971823746fa18b3ca2cb80f851a8dd94cf5 Mon Sep 17 00:00:00 2001 From: Luo Tao <luotao02@baidu.com> Date: Wed, 7 Mar 2018 18:08:00 +0800 Subject: [PATCH 14/40] rename concat_functor to concat, refine CMakeLists based on comments --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 40 +++++++++---------- paddle/fluid/operators/math/sequence2batch.cc | 1 - 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 3cbbbcb328..5d436a7e0c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -201,7 +201,7 @@ op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor) -op_library(concat_op DEPS concat_functor) +op_library(concat_op DEPS concat) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index e88f4ed1dc..11bc176400 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -6,8 +6,8 @@ function(math_library TARGET) # But it handle split GPU/CPU code and link some common library. set(cc_srcs) set(cu_srcs) - set(math_common_deps device_context framework_proto) - set(multiValueArgs SRCS DEPS) + set(math_common_deps device_context) + set(multiValueArgs DEPS) cmake_parse_arguments(math_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -18,36 +18,34 @@ function(math_library TARGET) list(APPEND cu_srcs ${TARGET}.cu) endif() + list(LENGTH cc_srcs cc_srcs_len) if (WITH_GPU) nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) - else() + elseif(${cc_srcs_len} GREATER 0) cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps}) endif() endfunction() -math_library(math_function DEPS cblas) -math_library(im2col) -math_library(selected_rows_functor DEPS selected_rows) -math_library(softmax) +# please add new math_library in alphabetical order +math_library(concat) +math_library(context_project DEPS im2col math_function) math_library(cross_entropy) +math_library(cos_sim_functor) +math_library(depthwise_conv) +math_library(gru_compute DEPS activation_functions math_function) +math_library(im2col) +math_library(lstm_compute DEPS activation_functions) +math_library(math_function DEPS cblas framework_proto) +math_library(maxouting) math_library(pooling) -math_library(sequence_pooling) -math_library(vol2col) -math_library(context_project) +math_library(selected_rows_functor DEPS selected_rows) math_library(sequence2batch) math_library(sequence_padding) +math_library(sequence_pooling DEPS math_function) math_library(sequence_scale) -math_library(maxouting) +math_library(softmax) math_library(unpooling) -math_library(cos_sim_functor) -math_library(lstm_compute DEPS activation_functions) -math_library(gru_compute DEPS activation_functions) -if(WITH_GPU) - nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context) - nv_library(concat_functor SRCS concat.cc concat.cu DEPS device_context tensor) -else() - cc_library(concat_functor SRCS concat.cc DEPS device_context tensor) -endif() 
+math_library(vol2col)
 
 cc_test(math_function_test SRCS math_function_test.cc)
 cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
@@ -58,4 +56,4 @@ if(WITH_GPU)
     nv_test(math_function_gpu_test SRCS math_function_test.cu)
     nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor)
 endif()
-cc_test(concat_test SRCS concat_test.cc DEPS concat_functor tensor)
+cc_test(concat_test SRCS concat_test.cc DEPS concat)
diff --git a/paddle/fluid/operators/math/sequence2batch.cc b/paddle/fluid/operators/math/sequence2batch.cc
index 72bf2ab170..8899abff36 100644
--- a/paddle/fluid/operators/math/sequence2batch.cc
+++ b/paddle/fluid/operators/math/sequence2batch.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/sequence2batch.h"
-#include "paddle/fluid/operators/math/math_function.h"
 
 namespace paddle {
 namespace operators {

From 49f3f1db0796c8cd06caed5a8de3ba11a68974a3 Mon Sep 17 00:00:00 2001
From: Luo Tao <luotao02@baidu.com>
Date: Wed, 7 Mar 2018 18:35:28 +0800
Subject: [PATCH 15/40] add back framework_proto depends

---
 paddle/fluid/operators/math/CMakeLists.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 11bc176400..a181d80226 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -6,7 +6,7 @@ function(math_library TARGET)
     # But it handles splitting GPU/CPU code and linking some common libraries.
     set(cc_srcs)
     set(cu_srcs)
-    set(math_common_deps device_context)
+    set(math_common_deps device_context framework_proto)
     set(multiValueArgs DEPS)
     cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
                           "${multiValueArgs}" ${ARGN})
@@ -35,7 +35,7 @@ math_library(depthwise_conv)
 math_library(gru_compute DEPS activation_functions math_function)
 math_library(im2col)
 math_library(lstm_compute DEPS activation_functions)
-math_library(math_function DEPS cblas framework_proto)
+math_library(math_function DEPS cblas)
 math_library(maxouting)
 math_library(pooling)
 math_library(selected_rows_functor DEPS selected_rows)

From f8e0c41e0e1ccb96781c00eb3fe1974021a493f5 Mon Sep 17 00:00:00 2001
From: qiaolongfei <qiaolongfei@baidu.com>
Date: Wed, 7 Mar 2018 20:06:22 +0800
Subject: [PATCH 16/40] add timeline profile howto

---
 doc/fluid/howto/optimization/timeline.jpeg | Bin 0 -> 70606 bytes
 doc/fluid/howto/optimization/timeline.md   |  27 +++++++++++++++++++++
 doc/fluid/howto/optimization/tracing.jpeg  | Bin 0 -> 30668 bytes
 3 files changed, 27 insertions(+)
 create mode 100644 doc/fluid/howto/optimization/timeline.jpeg
 create mode 100644 doc/fluid/howto/optimization/timeline.md
 create mode 100644 doc/fluid/howto/optimization/tracing.jpeg

diff --git a/doc/fluid/howto/optimization/timeline.jpeg b/doc/fluid/howto/optimization/timeline.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..38ec3f80c982857531f30a8bb0fa26ea5bf05385
GIT binary patch
literal 70606
[base85-encoded JPEG screenshot data omitted]
z4T=Z{jr;Fce_kl;e~UpFb!Qa?*z_sDBT$D`HHhUFF<<@6%EY<G-37c}e~lh>dY8W} zV_-NG^?85BSmJC`pMLIwNS1tDjffZT-Px6%+ZlwVmGb#Kmr2x6jtS8{&_~RtBM~1t zDpn#O%dI8yEoQPMlo6aD{PNX)jnjSiJwTiud04Xe>Fn`=KTIOnN`hS`!0eAMm{|w{ zw|5g>U9LwCmNUf*t6@aEl+Lb#Oh}PqxX5bLJfDs`q7e$S)t=yhb;e1O$XF?xwG`Zn zeQUM;^ca1{{(Aa@9cMrVa3GlNQJ;1cGXicqS-?oXcDnd|wzjy#vdnUP*Art;j^Vab z>bGM`pAT24XaYIdCET9q%zm_>cHbK68-$>QJZ~3=MdM-|lCYF@bEE7jBPJ&xm|P^Q z$k4%LAKq4Bq333gG@L2(fRAUhe&-1?auNvR8_W<|;>cqsoj;^mPq0%c<~Ua5tIhpD z=LMEN9rT37J4v)2Mz2jeB^z2~n+lj-zaG3`@y2TF@g2Bo<CNn{MNIOJ^2SJnWHU{K z){HF1G-8gLc}zB?aIY%1_4{DmU2Hww@8WbV9fn@mNnJ6^=PKUccuCLc9RpJqQd|~| zz-I;;#a|z(fYaz$jkYnLw|X`%_EjIB01xH>MqR25Qb%St`2zyWOz)~@i)!W@&V7#% z@ct529>JUZ<&_IOqx1)onMOxBMOO6AnFn5u&)Sc-xb_o-YL33Yd@SdF*ee?0b=oJ^ z#y7Xm7D`PeC9b!2>)equ=Y&7e46f?XUs^&)5+`Y}<w*u==moa*_Z}ZT5Wv2D;1E$M z@-83hix1oPb;)ME!`a_6k9IW-A5(%L9&jGIc!6sBo}xrFjGP>?o^VTkA5^U~!DQ#9 zy?|9Rl&`fziK~{m@Y`c=--3~jys)@Ov>fx_tjF+C)W``F?G_ik8lmj2il3t=#1eT4 zB@4G#7IV8yk|5WuWosSe8O>hT%UEbr<DjUsaS0<Tl*Gkmb8#0m1?3x#r)iT561w(u zdz}A&R`PXI%>i>AThdiph%!wNIC1T{x!*zD0*+qcyP()mBu(a%s{4+m&o?lOs)p=K zgdGWdy=OhtedT?b3BSAV<yx(3f9E5t<E>GnH9d7cvP8gFm<i79A!%HKE3q59!4rR0 zq=D03#Po4gbVknpLmijt1=S@M)dr^E#5{-?vP&J<#?&x>c+5+}LcnXfn3X8UT~H&g zs!Ppx<|VSOn4AjOsR{sDD30!e_Ph>rX+Kx9SXMjAlmcwIhp`iZR4zwK1~BdReCvO! z=T*aHEaL&CNJ4*Vr0ch8uFG7ycPqw#bNR20X#83YSFH<R^8-TG{G}0|-@vK;txp!D zJ#c{J{tGzxZ{Yr3&#Uoo;DUiLSh~o+uwV~e+XRxz_1Ahop+k!T=E;9=27o8~bH#t) zD<Y!+=KZ}LfO)^x{0G3F8v?lZTgATsRQ|1<KLb?QV;}zaR{o4{Z*2XLb%lR|=krU) zELQjU#{%%7dBXkgd;s7Tl9PY#msk59{{VR4VEixa`?;Fy^54+;h$4Ss-p{pg)qltK zOA`R=eyR4aL6iO8LA$-jI1Dit@*6a^pE`uA*_(F0z_k1HYxe<cztpo>-(wqKS9-Yq zSGN6D&#QNjZ2-2PTiF}bpR2hp{~a4}_w)<j*n%;C=l0Lwl71eWJ#fEN{0BPBgyYYH zT0s7#;$P5-Qh?-FKee+5@8^21%TM=Mha%?eh4=m$-mmp=HGAXgi>&*t6=xKCwF*gj zZ-Y`b0+eF7A$0we&&-besPyRLZ18BCr9g4#TBEnQh1|{+WXlQME{FhpuBWBGyD@_t zH>cbHWF?>@k4*w(WhB`Mc?3F}ikvzP=$1kY?9<*AH(>s^g`TTw)41^s*ljWRmR30s zb_lq(@^>elrG@a&^z8S9^2VST-Tw)pIq0}K;%9_4l_4M!m}~_931&Kt8~0P=3qv2? 
zv2g$hTz^LJ9}#*7cuxG$;lELW{-qngQeyjSU;ai3ke`1r?Po~;Iwin(J?&4FfY=zb zzmmdii~SDDg(OQ0`5P!0*$Dbetpdli=c-p}+}MVHRGqh_93UNDGOQ>jVFymWnz~Q> z=Bwk^<BUk$NGz}gUT5dl-m&oloco0|bli;c3raLF1Z=;z2e9tvrWS^PxD)LEwWJXL z3>EknE&!;S8w!5n9e^s(N5La-!l;fT9YwwlEya}j0Bcf43CJ3vkw7@@?wQixj37-I zqDsIGIXN%=b~GGTaV_Q0o$EC@8=X67pi`d%X+yY%K^W!b$!#2VhIhW?W0S&2jD`6C zX{*D|B3>x^ZEndd=IcfqY*1jr;JXWSBLXk#PAM!uKd8A2N--zFk~UtF<G5~VQ`1rB zwJlr9Cz*|Co?b>T=b)rzuU(Kma_w|BkaA(NyO~y@ObdH9_3f+#F-NTF&NCrwLGnb5 zEsO~N;RU5Az=7@^LjI6M(i9DskT&38!zG>_Z8w))kd%G%Wz^Ob9Aa3OPAl;i*g8B; z>u(sEpzv(*=I-!NfaEy4AfbNDcp6OviQNU!t^iAS(%?J!>ZEc*0k=o-Wg=#yc>q~9 zeP#=9U5=E*JH}U`-*B(2KTD_sZ%IPvcR>X`ly)ywXm}2so8Q&EM~~1gy|#KW|Jb*M z2|4FXL!(ew+T77?p*59X$6-%CIFcm)F?}}=OX4A)fcW@iL%;w4^ULJN{5*ieV(-|1 z^T@BYq2s0$K)d{pbyODu`d`sPZHs~by1MshvHn3zlorB3)64tI@zIo`1<-l{tnM5@ zB-S(mw3h#Xg`9ndoRZiRjKH+QGys&KTi5r*-NLZVf0EKxbAts3Ip&>h3vy3ZqtMgc zc*hMZz^Sddx53GwV{+4dsTpA@l!t&>?`R{msv^W81I6D~&#e$TREzv*9-H`WqdCV0 z*cQ*dUC^iaxLqLjqtS$0B(Qk<@TMs0VgB`%Z_*cHo%Iu|UOH5xr0}!li&l4Z^m0Th zjWjAn30fqPE;!-b@V=dRDzFv!1Lx;}*W{QDaaqW~rbvV|+8m>@>4Fk1!}p=#e$;Eo zrN}8$4cevRT@Y><+;L|xfHf)mTxnV~T(W;Vw?{S|ChIy5__*QWWXo9_;Ffzi=v9@D z0`0MrS@V(8{X8CbjaN=izK8P1VQdS668O_hFFXHec6Je#P`Wh4ZAurz3Ok6lOQPz` zP?1-GK(rT4N_Ro8MiQvVb|4SR;q7mb%P-@2Tj~vCMs`7x-aBuNZ%CgPTa_jNJvUvs zxG}VWBF8)#^64$#lyBRV#%`BNN3LHa^~KTvV-C)1(#nJAMM)VY;81;Jv3VQSgm&zK z_JOxoD}c@R;alLSJtInerm+i3&ncjq*1DJutS{{heV%#wK8^qocnn*G#8yz>`v?y$ z?}D^PO6s7!Ly_wr1!x`{*hwJ5UJ14#$l~da1tQ6hVY|x5XS1D_MNNrYMW^qAq<zXz z+tOk+yP&;0N%+i`mW3P7sk#<eq5x!hOMiK|jV18q$D7hN8MwAL1)bX*P+M#;8+SFv zt{Ej<IN9bGDeDNG_gczu9CV+!(#VB+v<otI=$U+Q?Cz(by=F>RbK#^pU?RYI;5#Dn zF!BZv)A|!2K4$fdGHH2Z6?u!>o8E&<j{1x`Y6_J;uaG&A$Rr@Bdeo7xD7%ivVDyh; zTdByoNPB$vb*Qal@Im6BcX1|mV0Ef?dDXIIV{RMQzKL4Mhc3irtwiKyQT&H!bx?}A zB`tq!%dw@Wir_PaJ)or}30%hwWnhFOeF8V$lA{#FWd{putRFx2qZxm2+?vEg{po=3 zUp)wcmLl`)oReQU&+L6a1dS|3%vl3-*9n+iDtPKq@gp>dJroace$6<)(oF7!{V2*} zxoTCl+WuVqIL(lTh(ab=ciIDnkD8{=&_S2i-(C+eSc^xl3l8K&0*QGvO$4E(GnT`r z5ueZRsCuo^vdEQQUbwauV1@Z2aaG9KjXY1}<TP@d@B!e{3v)?-Cmv+gavyx!sYB<9 z=Mbr4?L758O_=`Sj?suQ+pr(awCMpLvo97#KI`nk*kFTsez28$dYTMTZ!=J@9|V39 z%6BHf<(0Ryk;9*Wd;xCxf%O<q<U-ssE#wYmGsH>ZQ*YK`i57LLdwFsP<)T=dL{2c0 zzXmnZV>yyR&437tq$Im<3<LQYZrmT5s$BtGy0=u<=NhH(CCG-L_kE{!9FRqY*aU`U z`5ov?WZS|h8u=MmB+7x%1z;JhDP)ZzoAIGQ(0l4}v$Zp^Qv#*H3C~3xysd3V!t2kF z$ZzU-By^vNn@(XxwuB?!ousBigmqlsn|fD)2jzm#PF-HH3YipyUA*`)16fku5G>_0 zu<^}@dKW-LxfV9_$3zGO)Lma;_0aVt%(h&P44!t$kA_L3Ihyk#Sj2BK=5490Cj$%O zP(juA)$E7dZ}p$dR=I|urlZ!5&k$rug56VjG#tj}DoSSGW~?%Ww3%m3eMb|@tWE$A z^7VJisD3!V(vU^-bO(kC{%9)_Fc^Ooi&<&}WO00D1?I<bDD^4O>rRK9uXXxi**&r4 z2vnG<^{l4_Yx0fLM{|3plnYX#k!k%K$5l37kt2~$rO`uEGsm`^1T$DuG}k{kU~C<O zHWvK|UgaOlw?8wCsv&{kS$&eV)hfH7lM36iw3nsY`a2{f?i^<AhHIklOd`SY`?}4P zvAaoav$opO_A0y}YAq#wM-8FzDaF`lWZLtIl;_GjzF}i)sxEDL;l?Lwx{nZE2s7mN zSPa_x#igbfIukTkr8o`|+66@}pti{{D(m&oOS>R_=nk+$&CV{QmGb5a@_)2gD_`A= zrc?<k2{7I?hTHU&XYGPw#PG{i7iV;-=`^8!+dlrlAq2Kuy`a49x&6oxL<{>vl7Ysb zGXbS+0<-%ruz7R;5Cr^wquhZ8@);@NPOfX`r1D=<vlQH7qk9*Vw6n+VzMnZdHG<OF z1=ZiC0Adfgr6VtbDfc7Ek5ECopjBXZJMIUh4>0-`Yvi~gBd`VKp;=o?K(@%@B>f3! 
z6PQ<P97@0}M#^++7pN?Dnpw9!Ft1Y5z#+{x=HSN2x)L(57pc(!8iiX(k}EM`QFj+4 z;yX#JHG*!-^Hc8@B(xKWnWpiFTbrKT*}Gx_R0~il7aJw>GN3itdpr918CTI|VImr^ zrw#du>>t{ZpPpTpp59^r-eG`J;Xn{7;3%U9oSy%93$Z2%&;kLy%OBKnmAXLS)t>Ez z4K->Y`jCea;Fdf?4j;>tZM2Xi8HC-5DD=#|qvi!oewOAa!td(B-*C8~#n;3g^?JHG zbi=eWgBlNYfP;P7__@0OpZ4B6s;TXH9K~zlTCo5sD6t_Z0yc^W9F<EG1wjD?Ay-im z5hF?`<S2+Jh!JUuLIgylMx{wT0s>N$CPsn~dM^njq#oW5HoTvE@Avc8Z>`^Y@Auvx zS&-~~_MSa6duI0RnK@v(+3n5p_`}iE$}Nue8}6-Mn!V}Hd>a|*ho|qn+_VDwjP(12 z!D3vSk&jZKg0V*AaA5rVg~ax|22N+8Ao2$XxRL_3rnam>KR9loPhKYtU9;U0&NZC% zM!mg#*9TQ-pVLd>{w@$-=j|{1<hTEwpFA7p4RB|<pA4ir84cC8uKth1W1v*Hwy4=z z-)|v+f-u1U>SGQx6w!%e>i_+`v2fab`7PXV=Kpvl?s+fT@9-BuPltha2O0FhQoj8V zsYz>?rG>U*xH#+zJ7{#N-vCzbEvD2jzQ4Mf<;NfBytSLengmKW6XC2lMu8c9Hd8_; zA597AXlOkJ`y_va4=jwDh-bL~5(5mUlg8OG;HQMDBdKMfw^<J(EvQ6J?KrMeZgWNk zW+Jwi_Rc&ab5BsPbsXIfJy?Am8W^7v+GTC|>vybH5ox4mJRH4W{j^PFlH~yoC;GMn z6Qagg6R(=@ag~4hJF6HEL}ri+hCAO5%fZei$N&q9G2}1-tdp8(auenjX;h(jO6Zil z9I!rG$!|f`d5x5UEQlkjG}4I-vuJd9C}WW$=1mHU8QZ>YO2`<)jb)pGZFFzjl+eVa z@L1SOAQ=Ah*Q{b#BA!J-L%d2|r-U|yqoQX2XoIDwa2{221lwqzN18p5mAQ(XDItK- z17ed8j3(V+U$}4E*b7OmE?|>A%(3MFh^<{A&bE3Is@@3>33@<DIjewU{Bd~TLk|!S z5mx~r9+M{eFzDUXE(RVXzILjSMEE(#I91^SG)fr&r_q?Y0NTVPZ32ZuLjLBr(9YJ( zgnlq`$ko;<p+IvU4tO<!=KPaEbhKSSZV?U(*xHR-?!#pn@q8a3?GD_}hqJ~2qd#mg zZ?y(OMgN(;JDSshW?R6r{Z#m!6*9OHNTvpZ>In*FqN1LyYgUgQ0DAWn235#L%PIkT z>IAdG$tZ1SD2e9|7JqUC9zXunh{XoX4jTbO^K@7>s@?{KvMa*bW+KDGQ(+_?i2s+? zzH4dZ{_2%Gn$RyW=Z`~ua7Q6D(2R_Gn&lykF^KQ-5gy-{eEV;IPR+<GfTO!Uf@Cl2 zmwY{7kRZdy;l91*T5YsEWTFCS1#S7O@6kl;<ThwLGAWuq(t_o5c?vWtBHIB^0-vkT zxW-8t-fcM!Qj_91CA1E=1g10r$n1>r0KSlwD2Ai<@Shs8w<E&UQ$jB5@Xzc}KE#+3 z8rHx64`zZz7(BXyUd{I;qjxS;W&s8T+)D069#PCzvO@BK5gRWjbr|C-u*LNrythFQ zm`n-ft^_7je>dpt1h~1bV@ha$PU_KrwhRj2!1KaWAigvyxsx=!N)gM?)DiK6J2#WG z!G;P;xBwFrIXx@U+jmo0uIK@~aV&rSg|PTZz#FNKaPNbGRp87JUpZ!*4Vc|6oCToD zlC%co|6ASNk6i}Xc%qt*mqGocwV%KVBt8oscXFlxO?tTym@#%34WB8YE9^Sn6Ww-2 z_DUxaPAlmhMcAt67Sd@|0}Xh?8&9tFvbF>zlHf9kzl9_h#uI5kPic#qKlKo%kY4KY z!WI&_f;`FIclHzBhzzK8KxA)QiD51PGNSe#WneDe&<kaZ>Ky$d8{^N=TOb>AP4I!H zK8{!W7NZNcaAn=BVQ&e=Vk@Yf;x`WOR1~q-PtdmOdQ(-h+^sA0%$Y-XbPkAiiJR?v z>NTIc3avtx6=2*Rm15IXh8DWc?J2?DN=7c+Y!1_NYTB-~@QiasSVhde<abFDThso0 z$)6LG`t~H4`}Fsz|F`{R5UDFr{dow8S3o$XVxAg*t%7`Ww}+KjZ;091C|7L{Y0+6d zPUbPj`=tKYznBg|9nG4deW8&+rMy!UnS@h>97F!A!PzN)r`I^Y#pERyZ)1+#VW*sR z$ZHzu#-lbtUZ=?!Jz16Ld>+^C9~?vIwcg|%CloUtku=Ca{?6%4C|>gctp~=XH>}~D z@_b4t1*3x=K;Y-8xCvJ-sMWw*p;SZ~AN$s$9}B!k$*dJ9Vl<}2_0{%RZFpkBqCZP% z$V=X#-rB3$K(<^n;Z`8QlghDZdX`(z6IpA0TiYW($GsV#_K5KGSD}CY;<obNycKA` z^HiSr-p}<Zh7_iMON2fiRea-rulq(DucB=1tS|>`rDjY;CXsjD7RZa&(BlmL>RJnv zDtFs0qr>&IwCrJ<t2l=n9I|}+!O5X68+P$O>M_)R1nQ`RCned&U@WP7@N1+ZzyIWk z%EL#Qo^rUZOJ&L!%)6MBEcO@la(eZIADTA|kDrK#c<$vuqXmII*pJ@pO#x?B4cs7t z*~?t>a?&iz0PakL;;MT#8czv@A(lpb@epeOY619+@t%@<Tl_ntOV;=sG>-IihT}G> zseaI}aA`5;S!T7vsypDaI6pq#3LXq8hdQYBXyFKOd+1~97K=5a*#q!d3Olc9sM~&9 z5hF`+;yjt3#eg2m{hca8Vz<gc6U7Feie(ke{$XcbNQxcZw0>As9A@;0@UH@ksP=Ra z)X`nf=R-g<_@nn^Q`$<I$S}}|Dv{hZ$0J(m<MKZhu%6TC**prcU!>-PWdltYOU(eR zMSM6VG>{n!w8van_n28FQH5WP9tQ7??g~+??7|gI33bD&@_(;wqN#G7V^#G78PZ3? 
z!}^<O<GNN`d9PQu4l%IU+<4v!b|X02{q2DkTf0^)bQ`=i3kdUfiG(m7o+u;c&CKYs zX6EK2Lp_S0%EpQTIoCsr0A0XRdP|dYVfJq#95jHP59=H=OHVUAp>s0#$^?4C&?;vP zoC)JM?YwJx=~DCF&ADvS2eX&p=`C<`fIwY3UxnYP!JIxv;EHNC9~aZUd}`pzf$f=J z5cbxrcCT{_cQAL$)^s6h12mliOVS7AZ?y{&Yq3?#$ioECQkeIF<31^`do<W7j8E?B zhSj6~NuLaG*2QEd-0m)}jmG(qy>O%I{i^6esj_f3aIc6c_TWORdvI^DqYAy4i9#qS zbqVrpfm}neZB+H6FZrD<{N|yTU8fSbsu=|-{GKsL8$KjzZo~nsYd9r@VF~M|L*NGb zuiC>yTCClM5!Di}k*wvaKI}h!Q0|uQFMAriUJ7`kZ}<VVOCs_d!HYOqdd@^P3+TQT zJUMZHAu4DM_|Y1SE|<4K2@vBgILnp9TvJBcHhhA5IzSqR5j}QhBZIh_cY&or?=cC` zW&jgz+-PzCic)C<FJ0rTjEAwCuC+-!J3RHCZL?Z9YVk$#TcY26-Q+`2vUL%VNRzgE z!uEn|0g`qdyR?<cY24>6+pVnV$xmLWcKc?Owlg*M&Ya4KrI+yo<Tn$FR7|T>kf&Yz z1=2F!9{a1$wHGlN<w+KD24Zg(FS@np@l*f(I=yeam&ny*RExV=7DJ82Z#~|Jh4NoI zuMG{S4sY}sGCyit>5bVDUKL7riRrqlD*6%IH#R6YzLz{>F4<pD^ZLE(TWbux<km-? zK?1c;*<ss}`KjG{8+KN;tKL)h)#<q+X1?e2zBLgPg-5*nVi!<jgm}_q_SKmyfdDq) zZ`dSP)iM?`!})ROW>$|5YjOJk=4Q)mmKVJFj<cQ9XJ`9-b5mcqZkf$vw%wvXh!6hc z*DV}AYh=j}*Hi*gg&(3ja5rFyA$fVi2A)dp%@)*hQGfy0AbqDrb5TkK+yCjSR+}zW zlM=DQJ#uVp*E7W5rKM1*f)<=Mfu<Kw>LvwQe>Vb7d}j<mdJWs}r0;Y-jMJ$ZPm1)k zK(gw`D|g%3=DZHw;4MoT+54_@>+@m#H>nE?5y^tZJnaU|A9RtsPDaZvY*`)BWXB22 zuWDCXw~OtVt6Za3UOG=#-G6!RD@lu!OoXL3LHfpf0^|~ZOkS{@0$m)rNtGFr7A~x; zUctIjUwrwzt4bC%kCXiSx`quoqw}Jw&K!>oH9Zlz>fsXIifBWSXj>r9lcT?)W!=Wx z#M)sS4K7WRE_{@4Qyy?7rKLxI5nb^o`|L|Z_>$w%s$>$(f;v^^JtCS;3C%*|!LBIM zO$L?k%TQHPn=WsYM>UTJcicwBIz7TmNWN6aXNa(_zd)hxB5wzN<sj!VrAGR3(^$Ih zUR7Ng;aU10MB=RJ=5Mh=c>VhBh5{*8upvV}G%xB<nq8K$mF}8rwl?X93N7cXk-EIU z&(y(UC|jjgv*+1APTSU^z#5?&y-~{s$~}<h8GT+NCNbB7cERk^SYlT5IZ*N7n#S!l z6l|hfQ(?qr9})kINKc~|wU-VZkQ)ei9+PqJ%hC-^yI$yw&7ZZMa9U@VzJ21nq3m8P zJ!k+L16%a249ghYTS%o%2|*ZEJM8OL5XwkuW#+eMtIv;d>-Ym<Y%<z*_*voergFDU zew+UoHLV;zw6=Kl$u%AGh$1{Y7Sv8#Ufo)~(2b_uFfjn{s(O+BQNAY|E8A|=yfKU| zvbe00CR+Dd>QPkfFB8ekz^-@QhWk$g-&`ugd$c1v$7$5D4k9QbK5|oPCGm~9^=uZd zV{fER%Y3ERwa!*XgJ$pL$~!N~U-U0grWY1|S>xjynjB(ooH3m3BCv*t4T>y%CGrfO z3by#p&;_$t#f%lq>x}9$C4Lk^sKtV%ws@DpVRK7{_C$Eo9jS}gPiLlk@9)gJvDA`q zWmRjm5lOL)1k#^H+QNrKo}txPsf<dIHJ{zG4(EP!(ja_#zk487n)H76O63}~??9?c zsBj>%cX{DjMZ%{6Pj~7hkMu<?sOxZ62D!`Aa?H{d`U7d(^a=U4;h$qmBJLjwtM&-k z{UqeT#zlo5tu``i2R0YX;YS9eD;0;N891hQ18_pitIOkC=F{7u4b+8;G<XS3lVjU~ zBf!({><88Ox$alk%BUIkkwFi^J1nSs%m65=YmeBL2^>X$_01XI$qTR^LHWkhtyE-a zE`8&^SD7C($hQjCzRJjVkM%y`Y-y@*-I#gkIlgK~;9bh>y+*?GGVw>-?VHWpF1>88 z_O*x#_i|SbcXWP=J*-y#Og(1v`7`&=H`)1Jp8G&~N7M!BO|jKA7vw9$69zH#`}Wuw zeD)I%d=K5c6cj^at&MT(>DZhWo%z&3I`^Z}T7_=qx;a~iO-x_6oTjc{aOQH;>`R4B z{%jgEo}cI>zn6K1K@3D}vA5I=quU{o%4f5cA8ACzd}1_XwHB_teIp^G;i1;f4vi(c zl4=gem{vGke$((~iLyu0oGrhgVkT_ndoNWn8M%+sw%+%8NA@p(@K&tc1pF!b$Y3Q` z<4a?o=H><`wsWM-$^08un}eQhd=S)ey+D8SeM6zk`zpM*m(*=zhnS7<$n4EU3XJYv zFJNQMz`B-5fU}&>x~zQtLx+r+w?2GYd#v@cZXCg#2eBd<18p*8&62H{p~WdD-0BZA z-lsH_`KjyLpNnuxxq7bPA~RgM79)T8aWptH&N~tC;;w2vD(X5^JC789UHB$Jz3=J= zy5jFq>uSUC!LugKKV(75d)${8ZFs!V6*FSplE{Iw$(By8^EZb)vPkm!Kt&<vkqlxj zj-~HBbCZPs*d#Cd7VAkO^!o$fNoN8V8Kl5prQ&g?BlCKAn{vK!k?1c?A+vJ9!WDk7 z8^G}Z{Ehx*VxqIa33t8%a20%&HNUj3l@_`L<e&jC#(#+49NITl_0^tkJS=Fe{wABL zpb>b;fpGTK-v_`Awui@n+vl4^Qx?nu2_)Jn<9}^6IWiBFYWXpkC~3w(0+0m@`Z*Zp zmsV*Z99;AdK!G%C!Fqme?U@NO4U9#a2{ORWSCId#-6siJ2WGZv{%ozA2AT7HFhByM zRQM}ij_>}rwvbui%rO7wNSGOOm@S@nO$BKK<sf=`zFT%)S5X2yZT1yDeX<x_EF4s- zX0xZwM1iV~1q?|d2@J>=vQHL|fL7YfR;!t<(+~fb&0C;0Sql6OjN}Efvte|3pzC)~ z{@X@$bM~0F5Dw7bDA4QE=;0m4&_f20uQ&(y$HzsFfa^Y<wA#p5x`5h|ACI^NTX;+R z%iY;JF1ZvWxE&Aii4$^yem=(nyJ5zNe**Rb;h8gs)%|5e3ZyOMB$$`>V_wh!)CZd8 zUNGgi4x#oxJLm&4nE_fsSir#yrRv&!YlVM#MxaaoZ*?&KJ^~-uGhotx9N-_?!Sf>1 z5PWN|gNsbba5V*NM&9A8<41(|3m^iHo39}67BmB_e{27EL|E&;Xt)0PJc%AU0B)}) z1REMq>g0sztQtT{^Phua{>L6NJZTR=-fs|7Gyol-UmpU5nAt=9{F5bq=<<9RB_IN} 
zZ$A&XoQws0Oa@%eis%9D0qFAgKich}wr#*-?g5~#F#-(NFFQLLJd*bFk-u~RQFAf> z-r=B7FQ!Co{8nxL2h;0wE*kowvk$n%TL+5o&MA8SMsmdyo*XIWnwlcZKpmj`2ewbN zi?BI@oe5$SQ(f-W2<rlT+unk{g9a@5>ZL#Em3~AR_)llH6K9O)d-d7<Yibb70RMdk zUkp8r<?{_I3wmUqEFq%HY2U_c6L2f;>s0IGS>}87_pQb7qziyM7~t)uz>xm_Wg3Y; zpQfQ&0Bd0C9}_W@KA0SthFvh=$XB#}?*LD4h7W)hzYmXq;RTyIjOB+CKvt>**eoAR zB!obep*%x`#|O1Db^|g3GJX3x&HI!+fD_=E4*_<yYR9mFGwmw^uYmO%u?+G6(?Zj; z{<Xs;f1zGb5K&>qBm!Q^eG83EerMS(X(1;PJ>2aFXg>#e=Ll|vl-`4F4S2L=JwOz~ z^TPl>2dJ9)^<?ueW2}J%iw$=Rm|m>l+f}C{!UFWW-a-V}T3_~UYk&V9FW5VRL2OIV z4b#(nMc9LHV+auTJ}NzA7I=FlOxZmUh}9{mOHrCX{&2YUzU+uR9?@)$)LUNC{WNIg zfUCmz*>_!f?MM_%g*sO9M0J^ttdOYRD0nZ=#HI_cmc`JT>@v@3815`QrsLgw>#?}$ z{WW1`8zn=79&zN8^L78Ib#_lqRkB~`+2~YW(hys{T<_DE@0sOqk1xL~*66CRWYwM4 zh$!=_NdDoCN)_#tcrqI!kCsb`@MF{Axj@x&4ctmQEE(djZZZ0FB#%eeZ&3Gj@!q7k z7sgbX5_-$$5<hkO%js7PmM(kMKzabZlqo`P!s6->L%2inZT^PhQo;@fblcfpL&w2z z)d-<_Ztjh{wHqI(y^T5{d|oI-HeU9*@C`TDH=0XXe<E47*LlCs@sNsqLghzEUuhTF zYnUI}I%n)wWTWVrd%!a*G31b{Iaj3=@yF5wc_)FHZmYOVJMsOz{HNB<jTVjlm&z`G zy!dJC%DmTeUYwDBv)()Lbu+m9m=9j$V{deWlnDh8R_nVbxGg+VTTmE(VRAH&B9YR3 zH_{0=9HqbgjGgU+cgN=*MSQJ&Rw;hToII8?K)3Im93dyvtJqcWZO58(@!=JfHJlNq zl^|O%-+*aEPkoro(lY#nE%lI2Va9u@|2}!&$Xz7D<>d8l-Og}1@lulq?~0cjm1<%d zoFCDyL22FB6S4Y9g@QXbWskLyW1Y7Lle-`6tN6mZZh3l1MYOJvIS{>bLG7oz=Crgd zA19F2xkeE$4IU74gAH?EzX(5clQaCma=<xk=pqjPF7D=@6OEy|J2DS$FDJ+3o#2Y$ z_ckE5YcWhWdTN|!+h_G%>K-?PlXTFtT?Gj%^w;GD&)cza*SoIJgF=$aLa4r8TZ5Ty zw~b=%Pr}8UMw<S3Z0<R(`n*uh*k^EJBEW1U$qq4$!#0(6w}^OCj~&%}J}D(fZ`WCM z&8jq7@?LB<HQBLuKKTT~k5XU#oKO=14@oOAV+y3}ou1Wg%slb&_^RXOW4k?7!uSJ@ z-4{PdUTO;-oxF{lCf>n_!7oxZRd`o(8Jk)}KW-zBzjjEsx0l-?op(LX^TFwRtGj~c zz4fZ>I(v7`E_pAF=Xgnm(g-*h@-7m~sF4MJr{XlD&{YSDirhb!2l(-O4QKnLzqP!Q zK)arqy(Lb}e!v|H_QnQAjAmbLXrY~xYZy~uS)(rMis*c<vjcFjQOZ2zBB_}{9i*j} z{hLxt?Sr>_b$Om$t-JruGU_W$%<)FIBxWb@)y2oAH3E&eC%TVvW#1v3Xz&}u+I5}! 
zX`8n95sj?`ujM!^B6S>ZNI|!Nk`ODpST^n*YczAuP)|yvgf>@ri+62&7^1Bx{oq>D zgW#aB#eLX0@HxIpYnivYc1&@*WV3AvVP}|UzWv)PzIR9anP<<g*wA&uUcCMEp%*6l z<(ndkLsd&M=%t^A+3uR%sT_^;FD=X~#;_=vQ2B^7l5?;nhopkH;s*q4^TVB>Yb0T% zr?1fx#vHXg<pSw!{)f~*?&~pkDdiL?@A0Zgytdp&T7Rd=`T~Sy9|vC91j0t8G==-I zLOg@G0khjr2cjgDm)oseu_)O^CEVq9N#2x@Q%U*@6M{3b>H_C3_A_<t1fMhx-_YE} z3QDq%lA*UG@w4LGo|)RX+CA6Zn73p@m|cMGsw;64N2B!4>~-m|Ldy{eP~5IS0!N+Q zaV%;(z51qNinHo?(%R)s`YT<-9CBo`!&I(RHr`YCotrD!otkLaumYhPg7}xj0cVF4 zhm}{i3@jGUv}fR~!(0(Bu6=jZI@U%yi0VTZtc`cO#C?;qhI4rG+OYH90G0X_YKwjM z$A^cjf<!-$iEUcFN_;X|aS2*YTZAgWg<6{1?J-N!-IS6r!hKYc#@z$A-i8IrH{Urg zrAMr2-+jwhXYQ58Qd9eWav+q34aoCk80oNT>Ew2x;g;lCJ+mXm9#*aYvmyHR3-?1P zHh6DL5^hsi&&D;s?Xsd>M}JrGA?yLU&2>~-lVur`b7)@KV2gR9;lo*eu94j;54&w$ zlsC6Lyp&OWK)DkauhQWOl^K5Ob^7$2l%7RB!rP<BR_P|V;B{)ZjUWyQ_+|F=Cp<lR zS+UVVYUIjdvnwa4c2rNRob+&coD-udTbeny5S^X1BVjG@M5xO!W^3_N^5hu{i?G62 zP%+)vOB1bp%UG*xAU+vw<DSxEAa?@!q`O_R)AY2~72zt&K3X6tZF6+T$Sn)qQnGK< zu56YoV})*pn)}D~Sq4p(P2#74xXn!u3OzIH_Ey)+m-SA+4Eex1#g!k3(0$v8D)Af{ zA8E_+c59#J5F*#V3NU<CFE2AkFO$jG)vhx)r1JIhJNxI{FR8s^Z0?cj#i2zsdBmkQ zjq8Y|$SAzm?Yd;mkH8snFX1ie5%q1|`A=ElCaksGJ^j#{-6>a(7I=<`v0nA)rIC(C z=wIEE<@x$!x@4*O>CLN_S8NiN0`WtxSHa~SjS=rN{3Q5bW*xd-^P$6;G5N;&HZFdV z!za^CdK=v_-HWu=yiljKSc~uj(cPeo+*@RR-dqNC@(<P}+lSHV0lJxL`S}@@xs`L& zZsTKj?ZK&PXIq9Ye5bN|<nykBwXp4B8>VX={;cXtgdt&{IXTw2Elq39Rjmz{?*g{Y zx4dY1=*ctv<`=WlM|?XEZ{caSXFY9g2{^nV&Vj@05!<u(HO?I>*kwQ%ZqBkwA0;(e zGR;`hZQxDy5<YKQB+}@gqGRkW<b6|PZ5HTyc3GWFNsn@BoxfS8YUlDneaGBAADpLz z?7`I}G{0aOOQj7C)LiXmZbbi((S}>utIP~HTPmIRes+~siTd-|70HBi3MBz&dbTxk z?K{XtI~YlY;^<;EG7}0)w95%6b$Z1@jgwm(y9UQLa=FUG&7(c4swK8)4R$Wv=ILgj zPtRgjGi8|88Sw*Bu8nAYx0m+;)Qq-s8?G-`1~+;5beuc}d;Te3qY^c#r@W369Y$W) zP2g$#RHSnQS{=0*OPvoQnrC@E2a;2dW5nE(xhvttsEX}n;yL;kr!i&Q?xx%bIz%-< z4QWhZ#gj*tgLl}>aUNMncYRyQ)j~?sI-2!*4?l&^5f`ncomw$mP2E4{Vv8y6y3}0# zFsmybugltmEF~!5x3iv6%U^<MbO+^lmCq%}LuUQ+req=ag41kk>t>qcq!Zht&O7ZH zcP|GB3=Q-g<Gytr|1>4(@+%NL5LDAIrBb6VD0|<w@;a!u5!&G}?}K87zSzr%!#gH^ zf0$ysaHspu$BWxkw)f>Ny2G2_iU_;WT5&_QZ8Vl;%X&!ES;SvLn;#XBdQ!1bA>Frr z>6b5-0ji&ywp~%E&Ds9+<)*n+*U=zQj8ehgTAmf@z^H7c+#M!JuqNcSeY-xo1ktD$ zlmjZt9PQ;cq$!9BOTQ2!oJ@RGMkg^XD0mIVA3R4^Q%4i=HepXj*8S?k4+q=plhd2r z_b-tjI&Evb@wbYTVihmv>Q_n@r^m0ZJS+`o$Kfnfo6YjCB-6G<uZ>KZd^XZsKT@Ac zJ1xO72MSP@*gAQ8-10X5y{;~UcuO94?7o%4+z+FxqV%^`T=B9WPGmK;k&8zpc^mo9 z3F{QXpO#;xceGzLwoh<fwqHCl^^Zla#(Ix{>nBce1Yc$2@^S6E2uFU0S39L+HNhC( zs}QGeEfokA0yFs%R<6MBJ(g|+f}>k|ZsCPHpj!o7>BEuj4meB3l9aQ%E{HR%Qu5Q6 zi!R-8aP3jI)5Din7Ykj9zA#yIcq^qX>lR`UT_>z?#m;42D3#Yel-^>!cfPZOvR&#r z5G|<Cx;=j$%|w4}<TI#~Oe~2sZ!5?KZ!>_zWa#oO*<|}87hJ+E@DllVQsx%x&V9S= z;L&#vSKhq%;@;dZ+n-=^2%>0Zj#075O`w(@NAN%7x07zocjAoP!fl-AnVa+=LUC8c z;?XjxQ4{%|)sV`bR2lus1@|p-^euB81CeTqVRB!fUfI$wz|rhFWVaWU8i|jApyA!0 zLGr8``VKpHf}7=26s3&#I(xcxo)7Fa(>r^&IdVbd-JVrzhN*9E9<i^3T4O<Ji!xBL znR)Rwqs%vNSG=HpZbm6-adL!1y4g*8_u=K@nISq!-8w34|0M@sbuL|K+h^s<RHcvH z0@7JzK3@2vy0$Ye(QUOucb&SnEv7LwMM=`C#j48j_L6zWeJ~FnJ2_uYEzn}@WlFIs z89hY|H`@-|GW)>Oi+5E78{Vzf%1X;xsdt&867;n3_0?r(j!134F)SQrHgU2(xp|!V zm@+y^@WRmgubeX9W;do0UD-RqQ?j&=kLK(La!Kk1NbW%@3Rclm8Z!KNq6TGZF3%eC z8q^m*`*M7FuA{@;TK#(}tE+CP5N^PWy(r7wjko>TbNV_d4qwRNG97~3KOt?j+nRG2 zhfpC_K-+dkAE!AZ$I8b+0bl7E;JzTa_Hky-d*5wki{=Y2u|IZa$P>#5Ahvl1uWAYQ zX{n^!p?+cL$NME$xd+jz9PRAo*tM>v*~?~ykFHNYy0PaK8O9c+p!FsJB&s>{d~I@9 z=@V)T*=gx=$-z&|6AwCF3d;(-c3I(yFpb2wnyXlmj*Q$6^Sj-w;EvFA2gvBPy=?lx zVdeMb3SrFJU9v9a^A7o>2i*y?F)zj};W>0b)Pc8H3A|%m`7<}qLJx~xDe=3uIx0q& z4`pTb`@iq%(KHFJj9x92I;Z|w@C{3@A+F0_k+kjTgiEcJl|{IT3s*g$#}f(PO)7V` z{rnLWZLD7^h<W@`>+#pxlYgz=_%8(>kDPx}^P_U*e^i{j>A(0};Yy!4`E`B!7XDA= zDF3lu<t#zPOX+Xntkd)S0C|1f)7{h6F(nQjomoNimsLnuXw4|D`AgHx`kTLd)R{~l 
zq$&KvXd?>&uVw-V@~upyO_y0+G9|>3f!XH3F}?u2oVmcnofrjW5)?rM$6k<{olgc% z#fv<kZ>H-Dd1C407}?m6J;cB~8#-sCUAk%|X{qL%-h$1qGfjf@6W%{-$-CgbUgG?h z=bGL*4;=DfhiLLm`2taklGtsKt+@X46Y?+oR*k-0?{cwB((#T9`MGRdW!&}hmY(;_ za71i3_xQzani6pt89Cb=wZ#+8-<FUqN=&>aH2a$HZ^{1w$+MbN{0Kjd0i(BjP}3o} zh&V(F0CAc|&?wm+6&u$Uy+VAzrg3fC!TS|gRP&OkK<mjc%cfUN`nKxL{(H0@u$#Cx zGdp3LFe^cTM-<#;P2tJ)M#y$8vFX7L5r2*@3SYzv@Zt{YF4+t4a$bO!NF-BJfIRiD z$TQ}@{0Usa?U=hqx;0(U&(lNzR{{be=>O8if4Yk^N}1V3mx+9X*le0kTc;tWaIxPZ zwgAh*fVF)C*+Kwv;xx$EA$TUp_S2*kfILmg=hGm=8R!g<O<?;U>zeM=E9mrKJcAV! z-6Sc$qRmOdD|fG+P1eWc_5&l8${&yH^o)9z4rdg6A0&g!GOu+qu2@s{fkdVKgh>|t z^kY>(cwx0`b!D$&)Mm^F@EFUroVhBmmD+2ftMGmNly7)Ufs2Cw(q;8@mmYr7&o0Cv z?6+m72>9Tv=oP;mV<aYdCK{oPX*L&7Y8v8@X^3GeG85v4Um(^o7C^jbO6bB5JOYrP zqJQZ%y{<j*FP&n!g0SFeCKOQ1OL7|ME`GvHppB-1hDWD?c9;e_oEbk2G{T;yJ(Mxc zUK-zlPM-!EOC?PM4U&ri?YI23)5qyf&@Y{2(faA7O#&_y5JPu*X{x{monD&fSFVob z8cjn@<0VZ)y>=Sv4t~4<>JiiwP=h!G&4k){8fqewG!ts0U;BXf8v}Cq{o1F5Arh#$ zk!7y?nvyZ{8$YawV*5~r`B)^&a}G9C>_OLqCBl8!W0X6D8q%nq(v%Pv*!OkM6TnFC zU<cj_uoGbbvn3925U#zN58TbX#3tC+vq4N+Y9WYavqiz}0&41j0Em$QXxs>RWs@9G zv5{j(guVZDgl5rMt)wQbINVDD2sUft$)Z$N8W5vXLI8Q0v^MKWZ6P-qCZ&YLkpZO? zO*4Tog)$4a0XYBxLZxJ}%=2saNzLbH`~n4eCKSM3dL4v!<9}I*{mg}Q0WjbuXeJnq zzk<OAU|@huCvlcSg0%q=B|uE{>~CQHH|;C`<#{d`|9@^Tn>oK=JwN9E`uxv!F~j1L z0`*TIzKKrpfzzED1Kd?GRPg4bxu=s|^rCOBTJ~-T`v)ZqFA7pb&c?z`WoU@LAywkJ zg07+x2V^r;02w@E0$m{4zpq_v<fF+>P_o|!pzwFhcbaGdn$!S13;2Xg0BP(0MF;H6 z5i-!fjxZaD;Tco>+5rYr{Mw<Ap-as3Y;9gJ^0s9Y-RSdp)6$gP2=z#L`rbL1Z5OHt z=Vo8C@1xSOR8LgT$RET(h9YQ4iRG!jgpYTyW;cnEx<fu~pH@rguAYr*ZtA~EWV<G9 zWM<U*=LW94RSJB#`UK&PsGFw*uP1<S&QCYuN92GTfR2_iklsD!F<ia<NU>}V&LwNY zpfbCy`(fz86<yv|BKb|s)Vh%8YMWIai;Ra2+yPf)3*7mM<FyoU)ld!sfjY$K;Nr=H z2Cgpnp@|*NJmFEYPf}#Vozry|oJo5msI~HT(o@n_9$dw>M<vLIL0BUOJ(LA1uEzfO zdRPT~CEw~zhCVH54(~wEz$vSUU05QUbO|mt?L)=^tCoieCX0?i>uOm6rmt#u!y;ZW zlSdvj6=JANTnN0R`-Uv%zHmNX#@dH&i5`sifd;qJIKPu|8%nFi%<sp%8Xd=tYupJ? zP4a)c)pWT&4X48;MS}|cG_ZH}ech18K%v;8H$<jogx4%>55GDL*sJz4*-0O0l%IWO zYYXA|Gk?KQ5is8VZyUU3x8Uul(?9)~V)}8+*C}Q^{`HrCeEi#pCP04(cw6}E=Ed#@ z31ylB%`@XiU7#EQ!^6{pY#hV{j55g`)l%nG%rY&=-g&pl@JBK;L_T-&rB=hS>)i5= zH6SiWW7fuZr}!zm4X!OxQ*OOV9QVUGUGpgH*Y=vva7A_Cg#?J>R5vh+(U}r5$4qB1 z)#Yyg?Diq@M1FfJLYE+pdGii2j$VG&nTrKR%*K)9h`ok{tzLWWtGcD4UU6)hc;1Wn ztRd{vjXN^~ydMhlrC%32K%>Q<3KrHw9i#PVjxMj{<ZF<%a1|X#II5JCz=|#P{*E|L zf=J~%&KklA!!RkanC50glx8#h2&R_pi|Dm1f6BhF)<kFQ%itHs$L?!7sDAupr+HDw zQ=G_ZyMtoQ$C_PjR8_Y--rI$p-30E-_%uF<9j%S`VnYLM$oRa}1uR*ez=Hv9MLXFU zuA`bJZV6jWNXN*mliX5>O3?~~lr=D0Ztme~jT<7P#hK@!j8wl)ek>O=>8WKqBvSL( zmamOp1}?7N?L}u@<$Vf8Iv2qM-oQ*emuug3ti$BU;Xacc>^-|t7nKy#lFhkstk2v= zz*JUCRIA9bZ8~6RV51qaPZ`Jz*~aM1@{*K%1h%FgoaYHYCwnVG9fimc5%u0`FeQ|| zl8;v$6$@RHxjbYA?u05Y<n)x#5mT%Zk`k;*4`*i9#Ha@A7H+Eq0k<LzYT$*~OBy5z z6oQngxr5BCa5urQx%E)LTx+B*A3RF2Cw85f#4tvL`JV7#!_c2@SA36~gglyjWOV6q zi?!akhCkDhWCn_oqqe;=Jf#soeK~&cyLA{^`Nzj_^@32u-fQlk2Ek7W!zP7BRLuRn zDD^v(9x#uk4BrXMU-aoQZFDy9DiLt}d0D0O0;$J+Iy(vD$;_tvCPU4NK{wZ5+3Lc@ z^0U|&SaHZcdT+9QEbl|EA4h9ct%t7-p%Rta8==oFeyApfl>=8UU@f7ckUTN~#C6RD zD!&jN#D2(&Xq1;7ZOk|EQp?PBUmH}c{l1KcVJm|;2a!CZA)A;x!Ld%wz)K~<(}25t z_d)h#ksa&o)PXpwbVa3%crVDlq~WYOxaas0kgfo8?SYh{7_L|sYx20esg9q=%kvLg zzBL@OjALbRFwq+Uo^P4#h6{E=b(m+fKoYCeovr2k#PsMjS0Yz}qWSSd@*nDbKm`nq z!vum%Gmv(vFn5z_Eg*}Ogew;~%z7KhYMJJSs_=zJ2v2f#>vl%cHMa!3g~x?pz9-3! 
zNF691bP1HX2i$+we9=o_o7DMmt=#Z351dFL_p`ZL6DNUC+@?`N@G3zwkbw|TWW`vK z=T7FkgoB*2y9EX6t9>tMj%SQvJ8cj8hmO<|5(x3V`@s(1K3eJoM~D<?*6{F6P>)|i z$wIEhd6V70=V%`OppK0Dv0oh|vQL$*n;>8B0#Wzx%^AR^x~?71!}r3>YLLg6Z^U`z zOKAyjLEb$-a(Mq_RyU>BYh>k+!;=q=;z7BO;j-ch=(t9|D?e%Xjct-Tr40&jo#Yc~ zaPw>2znDPi!{T9zHfms*1ru<b$zU*Kx%HC0>thd^7~{0NniX13m)0JmGm_SRIR|Jc z2IBLFH*$}kkAA8ZJGQa~Ffn1z!YLtj4ZUaJtriD<-=wBn@%A}%o{&UAk9$&zaDe+I z{OW;Nj5`{Rssmd!H4PrjaBSl!h<3S-c(~wRwyNXN#H2>6%J8nrt0PZ=%s7FmJdBa| zc4YeX<^v{CK_#lSM&BZzLD2_sA$~L`6_6rKs=Z(tpdJz7o>d%wZp8BeV-sNhJhj@( zUBYHZa$k)P^YVxXLUp%YDTaLG29iY)l}U`pVeX0E%Nz6+f-=r@8By2tjCv=*#nvnw zdL7G_GD3C9H!^UnQD8he!CQOmD=T&;y7@0iZpy{sHdAox&o)?HRNq?Z3cn}U@I^qm zjC?wc{~I!@PT~=D!JfPc*hvi7J)IzYz=KDR$yc24AiA>PpvHH^dSo7<<`wV2CqKWu z^=-7Q{&~vb6<loS=w_9*b#DVob@3Dsq;`H<!2y(R&1|4o51s`#Mv>4U2H4WC%m>cm znYK3qCGzf3-|8P{_i(iYh7Fc|Eb}Q^Bd|0NtVB)alX^YmK5Sxl{A`W=`7b#dRK1Br zX#%~e$V;>|W`YtX(D-_|%Y<5ROOPs7Go-l?eJiz-|6CK?0lM+y3bwEk@WvghO5@|H zeXa=(4$r97epyXxg5EuOb*e~b<1w?Qz+z_)+Mp>zrah~e7HGYfAE3F1<&6`&U@hwF ztTFp|j;(z_zEZ{VyQ8~yojR*?q^d;Z!6W~ej&YT?JC{r7O+}8DTa)#lex5V`TC{e~ zp9-p_YNruZ{4rZ1x}dLh<VBYm_BVZIicjKjUMUYwUlH6}&=4)$pYeZKLbSVWp%Htt zg$h3p?)xLO-?@u47wUX?p2h^bK5n_GliEh!77ucHRJ^w6vBz!S-n269ClmZ{y^w1J zm(OQl`B#9Ht9y{tl2i!Wjj-2i%7I;H2K-1z#hk#+d!UFZJMa==mo}P&Qf1KhR=Z)w zh`S&!nw@`j2b>79rNKcg&HONi=h*^NHlsLeM4Y#qZ;pl#gWE;gy=c+#Pz#s>9Ouqn zc-;)o;w3)*tKnJ3faYjmx~qv%m7M<4T5Yk?Gh)5aT%<G6Z5LV$0s^1-kM)mZylJDR z>Z?KiEZ8_T+FrK5!+Fp^tS0FOOtTr|J>+6Q@?#C=%FJBwng6@Vl+f6#dGai9(^79h ze&;30KM6A_An*Sm-;I?EOdrS>!SL1WoY^9=a1s%vg8gY{9VN_KMWmDY8(!W$#aj*g zO$2Rz2;@{4X`&A5-r7~h;gTlfz}c`gp^wC1+f4~sZA4F}o`lKSV3}C{8vDLoSZ}Bu z8?0Hs2c&d@u9wh5(E?F72~U9h7M;jRAOOgMHAw`(05||iD$u-YyGIZiK3)aA*3i)b zH7CK8=Ww4aZ<x%@H=?6={ER@}lPmAH9(IZx{B9L63Fd+k-fsqn+I1tWxz%G9fQa{Y zfU$shh>n=p0hcBK6D=ArM_Q5?y-ZM<>)^s_79Gd65g_<J$idYH9TxI@b=`rw16aQZ zk1sO;$y3Nc8+<PnuTo}&7J#Spu5|zw!&dm`5IH4ad<v!_Z5Zlk5#a#w(3qrU4AQlc z32F!1%&Uy&SqKy%jjQBkXv|ZGSYY+pBKGLXqZ2UK4xG3D#@a{|<}^YBF#<a*y&EiH zEs;A0_bN3;0Aou8h4>}Af+aIVbU}jceSuxrmBqgdaTZ%*Uqj6S>;?mv#mfXh0Xutt zlKrGG_a&^xt^x-$A{tU{h$(?57nyTY#!2i?qR;?#n#4*uQ$o%KK$)o(`SI;RLN}H_ zFL!jL)pQ(AaE6*GtZ+1R1IMcg=JWu$s%^CfVYJ1=xr%6@MbC<100RZ$-{0EX1=P#p zz5GUa`~u}RIO$#C0@vj;8!+y>NH8c8VyLD#0c7?`fP>ltATj$rJdloXM4hz%&JI8h z#-}g3OtY_r`}AtUQAi0}2Iw_coWvLdBy58C661grm<VzQAn0!aAY~lH&J2;zKo8$E zG&j`Xf_2hZkW}bo8L;L$Gjd%g0V>Jv(D=nr-#9NoKC6{%`Af}wHZss+#Fv!D-*jM9 zgSm`lfG_4sJi34syanRnFJO%LiJ^vA)>r_V^B%0e0YE_v)asiO$}mZ&1Z?66Q%)dE zi#ig$mxpY|Uqm{w7!Dgq1@UC0V+3t9;BQ?4IFg0X0AudI`|NL(Fcm6G;5*x9XrVf_ zfKZRLoM&j&m<FVyqMBdJONu0u+K@?A4w(1gnNE24SUv^~De{*kJ57irLpF-+Z2+!6 zxn%x*#8Ph5h`AH^89tDhB7nY)aucr#<RBC(Uu|?b#$WF3I*@lPB=}6#jv%%mYFIYj zn=p^JBa<c1J=W)e+~9fAnGH^#u6@p<;ck9hHB_aoKdb6O;)#m4JT*Pj7x|RI6K#kT z&zT>sewBO04Ws~)ia9#j=W}##3|=ZcfuxmIRcFoa+9J_bxV}aIATL^<Z<JHNyQy!v zWs_d26!f}IC)K=HPc<py%>R>fpc|)z&H${A322RD#BybT*ogxC00*8sQEF^OVj2Rv zsv``*UJ}qNkZjKu>;n~6jqqx17r0Rk=8b|)`iOTPEw2XwtiY<Hb?jSnR_ve06{JJk z@`K}n%L3TNNGWkJ5<)@crT6IWS^uAy(fn5r8CDCPL4j1)qA1at-9{pfeGyjVV29rw zZs#{iety~YAWB~*WpD3=XW47;lmYad^S_{Y<HkENxzJC%&Ya-MlAZp8V5k2iW$g@N zTvM%(OPfJT-I??k>X1#Xp1~oQ8RVaRhdh@bk96K(k<n~Rpn(sWPT*!fT*!c~<;iqt zEEp-)-1L%r?h#(>T)_J<jl$!q>{#&509o~AmZ$n<kKG~Z+$8o`D^6gUwGi_{Uo2oZ z&uUPN<^2HjA_0h}w8T3i^BwQO9VquX;9NziiSyufFKNBz``fHHt-_5DwCu0ut4u02 zN7<aZZG7z0dt5(>WK;ZpCpEO^{k$AZU>bh_o&<$O%#s2rbW8<u`VrHZHK)YD&2o4R z5j_67j5j%%ATpRl7oDIZs3ADlFgPr=Z0Qp^Ta@C-E?oDhMX%50>R(kX;}J_aa!d$a zZQd%wUjR18UQe(adcje^lr<k%;NPDv=3`Tu2~ejlhkOe?3v7TOR@-0PVIN0WXR+5v z<8t7zFf;;0f6C&OOgolLTeDfEWi4+H)3aIA#&HhmVy@SHQaiO?3_2k2p%oe}G>?kS z86CHu06sLCB;U7NGbHS=IGx<Jf&aoGz=8JvqD?wV{71_Uy*>fYax=2d$;fZGh-W7w 
zAG0299pVojbutq|M^{zv;lA$G)6^0f4CZQDURzI3t=R==+%b7eCK-KVeok`Ps-sE5 zd7<IxY-7Sz?Cq$nOPs!#+Y{h!g1p_|n|sU)ij8tiFW)*np4`nl-tj|}yk+&MWqlLV zkR>!&)goSw4@M(BKUg|~b19lJU~|>Oc|*uiz(qb|W14=LhxJK*zN_M$afM1>ct?Td zOynY*0nAveA$Js}6vMrXvFS(Aa#Gs(j*};K+fM53zd~sCoD4a$#+?;Y$vng-g|+O_ z^(EN;b^!cdDF({I@CUpLl$XbvM6?{aL$Eu+BZfqY8-_$TBoV|Rv2nnb!tfgfIHE3# zcb&+;47@*-p)>F+-N8vw!ZCtq(W#r7&0e((5*tSo1%a}JYJw3E-)*A&mYUaG=gECn zL0Hk)r(FB)^e&U5`U56A*;|waPBRd^nH$dni^v0um<MbEi1T}bBHD=s+vc75kQQ2w zN&6D?fK>GG=v^=aPXp@8-$r_6G{}~n#oH2@@*0lEI<E!x)u&@uWp$1?Xl&{lb@b=Y zGQm^OOg-MAkMJi&zRPRWlmt*IMTaR_=zHwzorHI^w2$#S-HvV&$Fb0Yem>mQk*`C0 zB{DeJ|GZ_zJJZil{Yb&cNw}=U^SIQ;=!2rMoA{ran#d}o2n~^o`HxRED5G3Bh~)x( zDi>Hv;Dn$Qufn@F2-6+`BZdsl1uS4QH`YrdZOc3F8F9QOj{qD%Vq5q@ux;3NdE$*J zA=lNXV}~Zch_{aQbZ3<^bflj49o#_}Oa%@!<B?`EaG;GPx^N-jKqL8!0|~-*0H&b` z>SIP37MVPaI&%+#S1J#Vr@m73-Tp3jKfv?sXAlMG(6;-C?VBBfdU=0B9ULyp1FrNK zwYNPMS9xY!->DZA0CCY>A*P~8hdXnpgx)~nnE&vQ4X+0#l5PDxyd1@qhgbu{8q3BV z5or`=x|rzENiq=7*<QSp7`9vL<IS*6<etL#6s|j)9lGtLm4acG5d${?i5~rX(|6Df zytLl-iKxf|ufpqSq*uKVNa&04e*vGxT+>l)*A!=;nj_0~9+A9j#LJox>oI%3rBY<@ zzAIq=tLpPM?0zA#P%*&h@N#Ut4ZY=W4Q6Mq8*~QK_<3z{5&oVo52iVtN(T<XvN&Pp zetzV`IkrWxtn^DJcR0lxQ2w(yrR)wpLAN<9%79nZXJ(wwAvZuZJm6sZ8>4!H`zL$` zZtTr_g;{{F=Z`T5h6#W$W3;>O(GY0pSH50f^g#JDKu|ro;N<dR)N-s`;{r_Crcf|g z(1o5okaKWh2BHk=^f43<ay!r*#ZimqTyTP+<dHba2l-(EpoCo1pL+s|9=TWZV62qx zrw(d7+_<-O_I+$o<Ci!ij`PR?gJmf4F@o;cI3e*W*x!S!&^I|MuKX&&FHm+k7Fg)l z<)K!RKN)C&zj2pw&%tNNjpONe@p4s&Ctr;lr3F{RN4EGeb@);KzgPQeno3^-{A5-N zcOe-k-MsyMRU5elkrBggBD0{x#FgeI-Rvu$c_&_hZTlY%Gk&U3qYM)t$a_mqm7HTc zL}`{J^Rk`-m8Qp!syeO6($+NHwofqp8)%gF*(~5z@77n5k)(Lno^B7(gx4Km(!4E- zR_kKf<k!3-Gi>|j)dI{9JLz@NpLG#^F_#Kf@!}I{^fFQ!s7?#9A>ctM7whU;ouaNW z_n=*~O4cJB^59j!;aw@wxEusjEFnnNVzV*7J^%!k6bMTnmj!AQ+;xFh$$$c<FKU5> zBQmykPc6HeyBaRHHR1|4!-Jum#*>H-WQ+BIMVzDCF~7UPgA&&QW5x&+XT^7{FUpm> z$f|2Plm=$lwe~$}>L=|@dv1Q(VA%7|Ca<-l%s)5&5B4xTRmUIYie*hgM|D4^AlaOV z4DG$X{>bS}w5R|4*9NoaF>R;GUt2Xx9Qf5A8xmz-A5UV+;f4v%h%YVSs}o8W{1wEv z^8K5*q=eW(GujvaVT+i6|9_zW!*B#5FK*y9+Eb)D_ECm38%{Z=RS~P4nqT^mtL`Lk zuMiO)PrdJNzVN>)y7<@PjJk`A3nhw|QUl(sg(LO9zTkPLi|v2qf2G<pI)SAi7tvE< dnfUJGHK&dV3BQQ{FRMfT70|!=j+tuwzW~v`YM%fA literal 0 HcmV?d00001 diff --git a/doc/fluid/howto/optimization/timeline.md b/doc/fluid/howto/optimization/timeline.md new file mode 100644 index 0000000000..f0c1f1002e --- /dev/null +++ b/doc/fluid/howto/optimization/timeline.md @@ -0,0 +1,27 @@ +## how to use timeline tool to do profile + +1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. + + ```python + with profiler.profiler('All', 'total', '/tmp/profile') as prof: + for pass_id in range(pass_num): + for batch_id, data in enumerate(train_reader()): + exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[], + use_program_cache=True) + ... + ``` + +1. Run `python paddle/tools/timeline.py` to process `/tmp/profile`, it will generate another +file `/tmp/timeline` by default. You can change the path by cmd parameter, please take a look at +[timeline.py](https://github.com/PaddlePaddle/Paddle/blob/develop/tools/timeline.py) for details. + +1. Open chrome and visit <chrome://tracing/>, use `load` button to load the generated `timeline` file. + +  + +1. 
+1. The resulting timeline should look like the following:
+
+
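Taken together, the steps documented above reduce to a short command-line workflow. The sketch below is illustrative only: `train.py` stands in for whatever script contains the profiled training loop, and the `--profile_path`/`--timeline_path` flag names are assumptions to be checked against `tools/timeline.py` itself.

```bash
# 1. Run training with the profiler context manager enabled;
#    this writes the raw record to /tmp/profile.
python train.py

# 2. Convert the raw profile record into a Chrome-tracing timeline file.
#    Flag names are assumed; see tools/timeline.py for the actual interface.
python tools/timeline.py --profile_path=/tmp/profile --timeline_path=/tmp/timeline

# 3. In Chrome, open chrome://tracing/ and load /tmp/timeline
#    with the "load" button.
```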
diff --git a/doc/fluid/howto/optimization/tracing.jpeg b/doc/fluid/howto/optimization/tracing.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..3a49fc4f8a401a9463b0157e2f38c164ca02dcc5
GIT binary patch
literal 30668

[base85-encoded binary image data omitted]
zcdnw1J~@I83E-^16ky2i4R!blgos3@1p8N=&md1Sy-o=$Nhp@^dvSLGbp!mV_V*hj zD75AxINM&{lSUAtrLDmqUT(s%OOAtevPOiPO`H2D!?2p6A&dRDgAvk%J_Jlt4QqV~ z9n_Ib)v~|~*E6#g%d*VaoC2O71o<{j$)MwxBewS+bMKqcM}eosKJbr!M8I46$tS-n zt390b4RFlbTsuE^fOKh0u{<AQ5>|6PJ9^zYYd)f*1U&xVQG%Qc2V+>Cf^{ntp2n(< zKt6@U^wcP2yKbR!i)lquLZzGHb2;Glv*XT-s5tt>W>GBN?By{>lec!(Q?}O(eOO4Z zs#a(4E2f)@yDFbDVc`KJCvnW^qIpzXF0>0h)}qN>H5ew#V@x@;aQ`BcaDhjTi^r3T zcekqVKOq7Sve3qrV4l1q=QfMw_ay>d;zw*1L_nLZj0i;T^W*&Tg!AWW!^w72Qm1+L z!0j^F0zR%9#U`7u5y4@pcz9`_<Yc+GSmn&Y$0L@vHN{$~4?bD5S0&JdD=&KAw}W88 z-{U4;?X)r(mI>o5)+<?+T`{5`udu&TEY7G#{z2@U&&SXTk|y?l(_i!7jZ<>PRXFGK zkM8~5f$-Lp_giEBtGVpK+O1%kISmGcIQK&9A6BBpoN;TP*W{$+Uuv~U`aV1>unQxh z5a2qX)&}3@-Vq&kd~Zgof6zyp59*X@?YZCe4MqA(bkxH)owDX<$UbrgTk<4PJH)O2 z2lIdLnf`S5mxt`1c7Bd6ha+CDj#}wHbAPc~N#QA#uYV$e>6{+deY5Uv6@c!NTlxDk z%V6Y%`VsG^!al_U%)NE!SG!WdtUt%T)*uf#YP@mu_29}xnrK>?n#~$}rdmSf&%{8D z)bzuUgs&xO*+GuFCnbd}E+lRQ^jbR~4V;gQcy6Kd+)l^)bstSB%d~DJ<>o-f;#!b7 z-C=a_LcOuf<ws+y`py+zihaQ)t$bXV+H8F4LOjJ*S>#k+RcRvrTi}VbqbQZd!jSz( zeg%}TY+oncvC#*>PA?8CK^;<p3D?~DT%4I`Mg7p5W_~O=x~}bZnpI8Gqvfd&Ed>g_ z!aWIUr*u6Vc(nxPex7nK#ALS7lVA^BWSYzh%ceUT*A2{OPLepQ8|wDx!%`q0VNU}U zKOI;gcbMLR8?>=->}!pG<)d7_V8yE(9l?&JZkdQ%o8cN{>iqxqR`=}o8IwFbJGOn2 zFcvs6!7horfUWRoW>y{Bhj|+xMaxTMMHQW2{jy8!UeP-<22Xj7gx;njO{WSfMGxb) z$@+C+w71;UNAi5YZqcQs7j0bAWxr0jUuAxtr6g6EXN|Rhqik=XocIqzV1~Ih>$R0w z&&GSl=T17QGFjw&sgN;4?Sb=q(<*WLM0N2l>8p|RPVcyG_-OWyX~4zX{AQ0_6szC5 z@%Ufv31@x^4B87`@BcF}sQ?d3S^43hjLs#oOqINKZzB7z%$db-?S9<U5Br0+PIjB^ ze2?{R%jfCN-@AOZz1FcWyxa4u>(kt?Gpm<7R%+d0pLFk${cat-izoKG1@XUjkdF>O zwO`^tLsI_D?#JJ|kN4XF4^H@99xHb7R_*fN?g3l>9#QSmI~_e^@q>-e^ZY7pwaxM6 zv?Y(`J@#E({9yZ=)ZeE6&dX`;{J8w!dv;)lJKKC(f1=KRhFjv_tJf|mnK(K1>X8*j z9~YhoxRY$5BBy-B{)6B88^@2@8O}c*b^Nfm^QxG0+l|(khfllKuQlh%EoZCfhaIQa zPv6>C&OCYUe*9(5%5Asum)(mk`PrE<E7m{v<MxC3t@eU-GV@#N3_k4b{oZ`U@b>X# z9e-vh^s(L+*#4<Zf+zb(u3%8}yNU7kKZ5^pzWy!qLvziu5AP3GXXtJD>vlbK`>yKU zH`guNy*c;8rWj9aF99=d7wOYQY+rJ?E~Axr{iu%9tX;0RKA>uKiLLbG`Zwo4n1Acr za2eQO@7X8Pd_;Sxwe6B+k!#l}-v548eOtv+jvvyK%dRO187=4we)#@P{BOfQn%8Xf zAL#>kMr^zIHJ;t}%dFfh<>J0Zc8lew7F@PIbL_#Edaa-T(x+MUC#)e*7|*q|_+!PF zKQk(7eq8v>nt&^BlP)g2Xx>~}Z^If;`(VO}O`9ja{FMAE`v{^GI|QokLas%Xh!yQU zS*fQd|K8hp%Ov^LhU#CJDA(6Yznb^U`~vIQFKs@hropv}36AU&G7<{YAH6@cz{Ae? zqN9xbj~TatmD24CdkP=C7x6dU-d6gWH!i!BUn%mgYr%;%9z2K4k{A^ld@El5owh%1 zk*7WP9y8DPf2ZdHS96BvKA7K9&-G90Zfx8iyAOR+AD-^M67hVp|CX+sPh2gzj+q8I z*gUx0Kg0XPSKTl9tEykz_u!uW;(o~Yi~Ab%@)y?z-*@}PuybGUSN~e+7fjE7nP0X2 zlD}Y7301>oa%~7GnEzk<>$182i-rAb;6R+#=6!oF`+hKQo$t3bZPUq&3G?rr{Cn}= zRbWxpB7A87XW%%&MBs6nk5<E)Z}Ycb+_$e0O?Cwuf3?whqSUQme*X4L{`v*=z(eEL z|M^+}BobJA*ducoe*UhNe*HrJhocQ}^vEv$r<nggt$q|<gIyiaTff-<3HWpUXE;Co zpX=oR3|)&+xeNG59W)w3qiJX~mw`$%^zv%793oVrkA~1_2#r>*BsD+cky;}Q`2S!3 z`wlog@<bB2HZdj{*im$iTV)mhr&-tjbx-{-j@4I}=J)=Z=>O2sRU|AGxEGAB2hGs) zP#a?Gk_R5jIvQHUgcL{(=$N3<&>}vxE{%p3(WMJ$x5sGOB0g<h84WEG($>|{&>|tU GZUO)|OGNnq literal 0 HcmV?d00001 From af4036a675f664632dd88e6602ffd39e4350ef82 Mon Sep 17 00:00:00 2001 From: wanghaoshuang <wanghaoshuang@baidu.com> Date: Wed, 7 Mar 2018 21:33:35 +0800 Subject: [PATCH 17/40] Add guide for howto documentation --- doc/v2/howto/index_cn.rst | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/doc/v2/howto/index_cn.rst b/doc/v2/howto/index_cn.rst index 0c534f107b..b0268907bc 100644 --- a/doc/v2/howto/index_cn.rst +++ b/doc/v2/howto/index_cn.rst @@ -1,11 +1,37 @@ 进阶使用 ======== +PaddlePaddle支持用户灵活地设置各种命令行参数,以实现对模型训练或预测流程的控制。使用方式请参考: + .. toctree:: :maxdepth: 1 cmd_parameter/index_cn.rst + +PaddlePaddle支持在fabric集群、MPI集群、kubernetes集群上分布式训练任务,具体环境配置和使用说明请参考: + +.. 
toctree:: + :maxdepth: 1 + cluster/index_cn.rst + +PaddlePaddle提供了用于预测的C-API,关于C-API的使用,我们提供了如下指南: + +.. toctree:: + :maxdepth: 1 + capi/index_cn.rst + +PaddlePaddle支持多种灵活和高效的循环神经网络,具体配置使用方式请参考: + +.. toctree:: + :maxdepth: 1 + rnn/index_cn.rst + +关于如何使用内置的定时工具、nvprof 或 nvvp 来运行性能分析和调优,请参考: + +.. toctree:: + :maxdepth: 1 + optimization/gpu_profiling_cn.rst From 6f50dee4d5010d67ce6f757934031a30c17cc3d2 Mon Sep 17 00:00:00 2001 From: Tao Luo <luotao02@baidu.com> Date: Thu, 8 Mar 2018 00:39:24 +0800 Subject: [PATCH 18/40] compile and install the static library of fluid inference (#7827) * compile and install the static library of fluid inference * fix dynload_cuda not in CPU mode * update shared library and adjust the deploy of openblas * adjust the deploy of openblas * * auto add all fluid modules for static library * use libprotobuf.a instead of libprotobuf-lite.a for profiler * use set_property to set the global varible instead of ENV * add gpu depends of fluid modules, auto add inference_lib_dist depends * change the condition of openblas_lib, and fix a typo --- cmake/external/openblas.cmake | 8 ++---- cmake/generic.cmake | 5 +++- cmake/inference_lib.cmake | 35 +++++++++++++++++++++------ paddle/fluid/inference/CMakeLists.txt | 3 ++- paddle/scripts/docker/build.sh | 2 +- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index e2b7ef8d54..8af2765f58 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -77,7 +77,8 @@ IF(NOT ${CBLAS_FOUND}) INSTALL_DIR ${CBLAS_INSTALL_DIR} BUILD_IN_SOURCE 1 BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} ${COMMON_ARGS} ${OPTIONAL_ARGS} - INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR> + INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR> + && rm -r ${CBLAS_INSTALL_DIR}/lib/cmake ${CBLAS_INSTALL_DIR}/lib/pkgconfig UPDATE_COMMAND "" CONFIGURE_COMMAND "" ) @@ -100,11 +101,6 @@ IF(NOT ${CBLAS_FOUND}) \"${CBLAS_INSTALL_DIR}/lib -> ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}\" )" ) - INSTALL(CODE "execute_process( - COMMAND rm -r ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}/cmake - ${CMAKE_INSTALL_PREFIX}/${TMP_INSTALL_DIR}/pkgconfig - )" - ) ENDIF() ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 356da582d1..d0b5eaec2e 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -186,7 +186,9 @@ function(cc_library TARGET_NAME) add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) else() add_library(${TARGET_NAME} STATIC ${cc_library_SRCS}) + find_fluid_modules(${TARGET_NAME}) endif() + if(cc_library_DEPS) # Don't need link libwarpctc.so if("${cc_library_DEPS};" MATCHES "warpctc;") @@ -263,7 +265,8 @@ function(nv_library TARGET_NAME) if (nv_library_SHARED OR nv_library_shared) # build *.so cuda_add_library(${TARGET_NAME} SHARED ${nv_library_SRCS}) else() - cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) + cuda_add_library(${TARGET_NAME} STATIC ${nv_library_SRCS}) + find_fluid_modules(${TARGET_NAME}) endif() if (nv_library_DEPS) add_dependencies(${TARGET_NAME} ${nv_library_DEPS}) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 4471df36b0..6b2237b858 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -1,9 +1,22 @@ +set_property(GLOBAL PROPERTY FLUID_MODULES "") +# find all fluid modules is used for paddle fluid static library +function(find_fluid_modules TARGET_NAME) + 
get_filename_component(__target_path ${TARGET_NAME} ABSOLUTE) + string(FIND "${__target_path}" "fluid" pos) + if(pos GREATER 1) + get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) + set(fluid_modules ${fluid_modules} ${TARGET_NAME}) + set_property(GLOBAL PROPERTY FLUID_MODULES "${fluid_modules}") + endif() +endfunction(find_fluid_modules) + # make package for paddle fluid shared and static library function(copy TARGET) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DSTS DEPS) cmake_parse_arguments(copy_lib "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(inference_lib_dist_dep ${TARGET} ${inference_lib_dist_dep} PARENT_SCOPE) list(LENGTH copy_lib_SRCS copy_lib_SRCS_len) list(LENGTH copy_lib_DSTS copy_lib_DSTS_len) @@ -42,13 +55,21 @@ copy(glog_lib DSTS ${dst_dir} ${dst_dir}/lib ) -IF(NOT PROTOBUF_FOUND) +if(NOT PROTOBUF_FOUND) set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/protobuf") copy(protobuf_lib - SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LITE_LIBRARY} + SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LIBRARY} DSTS ${dst_dir} ${dst_dir}/lib ) -ENDIF(NOT PROTOBUF_FOUND) +endif() + +if(NOT CBLAS_FOUND) + set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/openblas") + copy(openblas_lib + SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include + DSTS ${dst_dir} ${dst_dir} + ) +endif() # paddle fluid module set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid") @@ -66,8 +87,8 @@ copy(memory_lib ) set(module "inference") -copy(inference_lib DEPENDS paddle_fluid_shared - SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.so +copy(inference_lib DEPS paddle_fluid_shared paddle_fluid + SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.* DSTS ${dst_dir}/${module} ${dst_dir}/${module} ) @@ -83,6 +104,4 @@ copy(string_lib DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat ) -add_custom_target(inference_lib_dist DEPENDS - inference_lib framework_lib memory_lib platform_lib string_lib - gflags_lib glog_lib protobuf_lib eigen3_lib) +add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep}) diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index bdb147955c..17ccca8cdc 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -5,7 +5,8 @@ cc_library(paddle_fluid_api DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) # Create static library -cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES} ${GLOB_OP_LIB}) +get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES) +cc_library(paddle_fluid DEPS ${fluid_modules}) # Create shared library cc_library(paddle_fluid_shared SHARED diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 06319fc638..6be2bd8fad 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -213,7 +213,7 @@ function gen_fluid_inference_lib() { if [ ${WITH_C_API:-OFF} == "OFF" ] ; then cat <<EOF ======================================== - Building fluid inference library ... + Deploying fluid inference library ... 
======================================== EOF make inference_lib_dist From 7f00716c87aea534a8e090833595ba1eff9b7e41 Mon Sep 17 00:00:00 2001 From: kexinzhao <kexin.zhao.paddle@gmail.com> Date: Wed, 7 Mar 2018 15:19:35 -0800 Subject: [PATCH 19/40] Add context wait in type_transform (#8850) --- paddle/fluid/framework/data_type_transform.cc | 1 + .../framework/data_type_transform_test.cc | 24 +++++++------- .../framework/data_type_transform_test.cu | 33 +++++++++++-------- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 554cd58916..c0523f3c79 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -53,6 +53,7 @@ struct CastDataType { auto* context = static_cast<const platform::CUDADeviceContext*>(ctx_); trans(*context, in_begin, in_end, out_begin, CastDataTypeFunctor<InType, OutType>()); + context->Wait(); #endif } else { PADDLE_THROW("Unsupported place!"); diff --git a/paddle/fluid/framework/data_type_transform_test.cc b/paddle/fluid/framework/data_type_transform_test.cc index c992cba9a3..6b9a8f5e28 100644 --- a/paddle/fluid/framework/data_type_transform_test.cc +++ b/paddle/fluid/framework/data_type_transform_test.cc @@ -50,13 +50,13 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp32, kernel_fp64, in, &out); double* out_data_double = out.data<double>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast<double>(i / 3)); + EXPECT_EQ(out_data_double[i], static_cast<double>(i / 3)); } TransDataType(kernel_fp32, kernel_int32, in, &out); int* out_data_int = out.data<int>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast<int>(i / 3)); + EXPECT_EQ(out_data_int[i], static_cast<int>(i / 3)); } } @@ -76,31 +76,31 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp16, kernel_fp32, in, &out); float* out_data_float = out.data<float>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i])); + EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i])); } TransDataType(kernel_fp16, kernel_fp64, in, &out); double* out_data_double = out.data<double>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i])); + EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i])); } TransDataType(kernel_fp16, kernel_int32, in, &out); int* out_data_int = out.data<int>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i])); + EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i])); } TransDataType(kernel_fp16, kernel_int64, in, &out); int64_t* out_data_int64 = out.data<int64_t>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i])); + EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i])); } TransDataType(kernel_fp16, kernel_bool, in, &out); bool* out_data_bool = out.data<bool>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i])); + EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i])); } // transform float to float16 @@ -112,7 +112,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp32, kernel_fp16, in, &out); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x); } // 
transform double to float16 @@ -124,7 +124,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_fp64, kernel_fp16, in, &out); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x); } // transform int to float16 @@ -136,7 +136,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_int32, kernel_fp16, in, &out); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x); } // transform int64 to float16 @@ -148,7 +148,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_int64, kernel_fp16, in, &out); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x); } // transform bool to float16 @@ -160,7 +160,7 @@ TEST(DataTypeTransform, CPUTransform) { TransDataType(kernel_bool, kernel_fp16, in, &out); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x); } } } diff --git a/paddle/fluid/framework/data_type_transform_test.cu b/paddle/fluid/framework/data_type_transform_test.cu index 3939bc5e75..de389ddabc 100644 --- a/paddle/fluid/framework/data_type_transform_test.cu +++ b/paddle/fluid/framework/data_type_transform_test.cu @@ -49,15 +49,16 @@ TEST(DataTypeTransform, GPUTransform) { float arr[6] = {0, 1, 2, 3, 4, 5}; int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(in_ptr, arr, sizeof(arr)); - TensorCopy(in, gpu_place, context, &in_gpu); + TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp32, kernel_fp64, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); double* out_data_double = out.data<double>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast<double>(arr[i])); + EXPECT_EQ(out_data_double[i], static_cast<double>(arr[i])); } TransDataType(kernel_fp32, kernel_int32, in_gpu, &out_gpu); @@ -66,7 +67,7 @@ TEST(DataTypeTransform, GPUTransform) { int* out_data_int = out.data<int>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast<int>(arr[i])); + EXPECT_EQ(out_data_int[i], static_cast<int>(arr[i])); } } @@ -83,6 +84,7 @@ TEST(DataTypeTransform, GPUTransform) { int data_number = sizeof(arr) / sizeof(arr[0]); memcpy(ptr, arr, sizeof(arr)); TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); // transform from float16 to other data types TransDataType(kernel_fp16, kernel_fp32, in_gpu, &out_gpu); @@ -91,7 +93,7 @@ TEST(DataTypeTransform, GPUTransform) { float* out_data_float = out.data<float>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_float[i], static_cast<float>(ptr[i])); + EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i])); } TransDataType(kernel_fp16, kernel_fp64, in_gpu, &out_gpu); @@ -100,7 +102,7 @@ TEST(DataTypeTransform, GPUTransform) { double* out_data_double = out.data<double>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_double[i], static_cast<double>(ptr[i])); + EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i])); } TransDataType(kernel_fp16, kernel_int32, in_gpu, &out_gpu); @@ -109,7 +111,7 @@ 
TEST(DataTypeTransform, GPUTransform) { int* out_data_int = out.data<int>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int[i], static_cast<int>(ptr[i])); + EXPECT_EQ(out_data_int[i], static_cast<int>(ptr[i])); } TransDataType(kernel_fp16, kernel_int64, in_gpu, &out_gpu); @@ -118,7 +120,7 @@ TEST(DataTypeTransform, GPUTransform) { int64_t* out_data_int64 = out.data<int64_t>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i])); + EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i])); } TransDataType(kernel_fp16, kernel_bool, in_gpu, &out_gpu); @@ -127,7 +129,7 @@ TEST(DataTypeTransform, GPUTransform) { bool* out_data_bool = out.data<bool>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(out_data_bool[i], static_cast<bool>(ptr[i])); + EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i])); } // transform float to float16 @@ -137,13 +139,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp32, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_float[i]).x); } // transform double to float16 @@ -154,13 +157,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_fp64, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_double[i]).x); } // transform int to float16 @@ -170,13 +174,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_int32, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_int[i]).x); } // transform int64 to float16 @@ -187,13 +192,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_int64, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_int64[i]).x); } // transform bool to float16 @@ -203,13 +209,14 @@ TEST(DataTypeTransform, GPUTransform) { } TensorCopy(in, gpu_place, context, &in_gpu); + context.Wait(); TransDataType(kernel_bool, kernel_fp16, in_gpu, &out_gpu); TensorCopy(out_gpu, cpu_place, context, &out); context.Wait(); ptr = out.data<float16>(); for (int i = 0; i < data_number; ++i) { - ASSERT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x); + EXPECT_EQ(ptr[i].x, static_cast<float16>(in_data_bool[i]).x); } } } From c74797a856160f55798b131bebe871b430e39627 Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Thu, 8 Mar 2018 09:26:02 +0800 Subject: [PATCH 20/40] add warning --- doc/fluid/howto/optimization/timeline.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/doc/fluid/howto/optimization/timeline.md b/doc/fluid/howto/optimization/timeline.md index f0c1f1002e..57b48a47fe 100644 --- a/doc/fluid/howto/optimization/timeline.md +++ b/doc/fluid/howto/optimization/timeline.md @@ -1,6 +1,6 @@ ## how to use timeline tool to do profile -1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. +1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when use profiler to record timeline infomation, for the profile record will grow with the batch number. ```python with profiler.profiler('All', 'total', '/tmp/profile') as prof: From 205cadf6b774e0854b1f68c5b7b44163f7ac1095 Mon Sep 17 00:00:00 2001 From: qiaolongfei <qiaolongfei@baidu.com> Date: Thu, 8 Mar 2018 09:27:07 +0800 Subject: [PATCH 21/40] typo --- doc/fluid/howto/optimization/timeline.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/fluid/howto/optimization/timeline.md b/doc/fluid/howto/optimization/timeline.md index 57b48a47fe..9d9565a3e6 100644 --- a/doc/fluid/howto/optimization/timeline.md +++ b/doc/fluid/howto/optimization/timeline.md @@ -1,6 +1,6 @@ ## how to use timeline tool to do profile -1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when use profiler to record timeline infomation, for the profile record will grow with the batch number. +1. Add `with profiler.profiler(...)` to the main training loop. After run, the code will generate a profile record file `/tmp/profile`. **Warning**: Please do not run too many batches when use profiler to record timeline information, for the profile record will grow with the batch number. ```python with profiler.profiler('All', 'total', '/tmp/profile') as prof: From ded34b2c0f648409c0dd970c2e1ff5efa4817091 Mon Sep 17 00:00:00 2001 From: qingqing01 <dangqingqing@baidu.com> Date: Thu, 8 Mar 2018 09:29:19 +0800 Subject: [PATCH 22/40] Fix detection_map_op for multi-device. 
(#8845) --- paddle/fluid/operators/detection_map_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detection_map_op.cc b/paddle/fluid/operators/detection_map_op.cc index 9b8ca92537..73c84c2fe0 100644 --- a/paddle/fluid/operators/detection_map_op.cc +++ b/paddle/fluid/operators/detection_map_op.cc @@ -71,7 +71,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel { return framework::OpKernelType( framework::ToDataType( ctx.Input<framework::Tensor>("DetectRes")->type()), - ctx.device_context()); + platform::CPUPlace()); } }; From 9a6f0ab287ff73ed4f8b993793acbd2dc8b4572d Mon Sep 17 00:00:00 2001 From: Yancey1989 <yancey1989@gmail.com> Date: Thu, 8 Mar 2018 09:40:23 +0800 Subject: [PATCH 23/40] add capi-noavx-openblas download link --- doc/v2/build_and_install/pip_install_cn.rst | 2 +- doc/v2/build_and_install/pip_install_en.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/v2/build_and_install/pip_install_cn.rst b/doc/v2/build_and_install/pip_install_cn.rst index 8e4165da6b..ddcd42a0c6 100644 --- a/doc/v2/build_and_install/pip_install_cn.rst +++ b/doc/v2/build_and_install/pip_install_cn.rst @@ -39,7 +39,7 @@ PaddlePaddle可以使用常用的Python包管理工具 "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "暂无" - "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "暂无" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddle.tgz>`_" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl 
<https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddle.tgz>`_" diff --git a/doc/v2/build_and_install/pip_install_en.rst b/doc/v2/build_and_install/pip_install_en.rst index 0d4c925b6e..e08c84703b 100644 --- a/doc/v2/build_and_install/pip_install_en.rst +++ b/doc/v2/build_and_install/pip_install_en.rst @@ -42,7 +42,7 @@ If the links below shows up the login form, just click "Log in as guest" to star "cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxCp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuAvxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "Not Available" - "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "Not Available" + "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl 
<https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_CpuNoavxOpenblas/.lastSuccessful/paddle.tgz>`_" "cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda75cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda80cudnn5cp27cp27mu/.lastSuccessful/paddle.tgz>`_" "cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl>`_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl>`_", "`paddle.tgz <https://guest:@paddleci.ngrok.io/repository/download/Manylinux1_Cuda8cudnn7cp27cp27mu/.lastSuccessful/paddle.tgz>`_" From a032f56f7cdb12a9a62f14f6619e96a2a04b631c Mon Sep 17 00:00:00 2001 From: Yiqun Liu <liuyiqun01@baidu.com> Date: Thu, 8 Mar 2018 09:48:53 +0800 Subject: [PATCH 24/40] Add profiling information for inference example (#8748) * Add profiling information for inference example, recognize digits. * Refine the profiling method. * Correct the use of RecordEvent and simplify recognize_digits. 
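For reference, the profiling pattern now wrapped around Executor::Run in
test_helper.h is roughly the sketch below. It only illustrates the calls
used in this patch; `place`, `executor`, `scope`, `inference_program`,
`feed_targets` and `fetch_targets` are assumed to be set up as in
TestInference, and the CPU profiler state is hard-coded for brevity.

```cpp
// Enable the profiler before the timed region (use kCUDA on GPU places).
paddle::platform::EnableProfiler(paddle::platform::ProfilerState::kCPU);
{
  // RecordEvent is scoped: it times everything from its construction to
  // its destruction under the given name.
  paddle::platform::RecordEvent record_event(
      "run_inference",
      paddle::platform::DeviceContextPool::Instance().Get(place));
  executor.Run(*inference_program, scope, feed_targets, fetch_targets);
}
// Dump the timing information sorted by the default key, then reset.
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
                                  "profiler.txt");
paddle::platform::ResetProfiler();
```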
--- paddle/fluid/inference/io.cc | 15 ++-- .../test_inference_image_classification.cc | 19 +++-- .../book/test_inference_recognize_digits.cc | 83 ++++++------------- paddle/fluid/inference/tests/test_helper.h | 76 +++++++++++++---- paddle/fluid/platform/profiler.cc | 2 +- 5 files changed, 104 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 80eb988967..52e9c0baa6 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -22,14 +22,14 @@ namespace paddle { namespace inference { void ReadBinaryFile(const std::string& filename, std::string& contents) { - VLOG(3) << "loading model from " << filename; - std::ifstream inputfs(filename, std::ios::in | std::ios::binary); - inputfs.seekg(0, std::ios::end); + std::ifstream fin(filename, std::ios::in | std::ios::binary); + PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", filename); + fin.seekg(0, std::ios::end); contents.clear(); - contents.resize(inputfs.tellg()); - inputfs.seekg(0, std::ios::beg); - inputfs.read(&contents[0], contents.size()); - inputfs.close(); + contents.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&contents[0], contents.size()); + fin.close(); } bool IsPersistable(const framework::VarDesc* var) { @@ -97,6 +97,7 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor, const std::string& dirname) { std::string model_filename = dirname + "/__model__"; std::string program_desc_str; + VLOG(3) << "loading model from " << model_filename; ReadBinaryFile(model_filename, program_desc_str); std::unique_ptr<framework::ProgramDesc> main_program( diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc index d6fc51301b..e9a27171f1 100644 --- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc +++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc @@ -17,10 +17,13 @@ limitations under the License. */ #include "paddle/fluid/inference/tests/test_helper.h" DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_int32(batch_size, 1, "Batch size of input data"); +DEFINE_int32(repeat, 1, "Running the inference program repeat times"); TEST(inference, image_classification) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " + "--batch_size=1 --repeat=1"; } LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; @@ -29,13 +32,11 @@ TEST(inference, image_classification) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - int64_t batch_size = 1; - paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [0.0, 1.0]. 
SetupTensor<float>(input, - {batch_size, 3, 32, 32}, + {FLAGS_batch_size, 3, 32, 32}, static_cast<float>(0), static_cast<float>(1)); std::vector<paddle::framework::LoDTensor*> cpu_feeds; @@ -46,7 +47,9 @@ TEST(inference, image_classification) { cpu_fetchs1.push_back(&output1); // Run inference on CPU - TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1); + LOG(INFO) << "--- CPU Runs: ---"; + TestInference<paddle::platform::CPUPlace>( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat); LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA @@ -55,7 +58,9 @@ TEST(inference, image_classification) { cpu_fetchs2.push_back(&output2); // Run inference on CUDA GPU - TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2); + LOG(INFO) << "--- GPU Runs: ---"; + TestInference<paddle::platform::CUDAPlace>( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat); LOG(INFO) << output2.dims(); CheckError<float>(output1, output2); diff --git a/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc b/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc index 99bee94cb8..1fb0f9e777 100644 --- a/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc +++ b/paddle/fluid/inference/tests/book/test_inference_recognize_digits.cc @@ -17,10 +17,13 @@ limitations under the License. */ #include "paddle/fluid/inference/tests/test_helper.h" DEFINE_string(dirname, "", "Directory of the inference model."); +DEFINE_int32(batch_size, 1, "Batch size of input data"); +DEFINE_int32(repeat, 1, "Running the inference program repeat times"); TEST(inference, recognize_digits) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; + if (FLAGS_dirname.empty() || FLAGS_batch_size < 1 || FLAGS_repeat < 1) { + LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model " + "--batch_size=1 --repeat=1"; } LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; @@ -29,77 +32,39 @@ TEST(inference, recognize_digits) { // 0. Call `paddle::framework::InitDevices()` initialize all the devices // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - int64_t batch_size = 1; - paddle::framework::LoDTensor input; // Use normilized image pixels as input data, // which should be in the range [-1.0, 1.0]. 
SetupTensor<float>(input, - {batch_size, 1, 28, 28}, + {FLAGS_batch_size, 1, 28, 28}, static_cast<float>(-1), static_cast<float>(1)); std::vector<paddle::framework::LoDTensor*> cpu_feeds; cpu_feeds.push_back(&input); - paddle::framework::LoDTensor output1; - std::vector<paddle::framework::LoDTensor*> cpu_fetchs1; - cpu_fetchs1.push_back(&output1); + for (auto is_combined : {false, true}) { + paddle::framework::LoDTensor output1; + std::vector<paddle::framework::LoDTensor*> cpu_fetchs1; + cpu_fetchs1.push_back(&output1); - // Run inference on CPU - TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1); - LOG(INFO) << output1.dims(); + // Run inference on CPU + LOG(INFO) << "--- CPU Runs: is_combined=" << is_combined << " ---"; + TestInference<paddle::platform::CPUPlace>( + dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined); + LOG(INFO) << output1.dims(); #ifdef PADDLE_WITH_CUDA - paddle::framework::LoDTensor output2; - std::vector<paddle::framework::LoDTensor*> cpu_fetchs2; - cpu_fetchs2.push_back(&output2); + paddle::framework::LoDTensor output2; + std::vector<paddle::framework::LoDTensor*> cpu_fetchs2; + cpu_fetchs2.push_back(&output2); - // Run inference on CUDA GPU - TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2); - LOG(INFO) << output2.dims(); + // Run inference on CUDA GPU + LOG(INFO) << "--- GPU Runs: is_combined=" << is_combined << " ---"; + TestInference<paddle::platform::CUDAPlace>( + dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined); + LOG(INFO) << output2.dims(); - CheckError<float>(output1, output2); + CheckError<float>(output1, output2); #endif -} - -TEST(inference, recognize_digits_combine) { - if (FLAGS_dirname.empty()) { - LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model"; } - - LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl; - std::string dirname = FLAGS_dirname; - - // 0. Call `paddle::framework::InitDevices()` initialize all the devices - // In unittests, this is done in paddle/testing/paddle_gtest_main.cc - - paddle::framework::LoDTensor input; - // Use normilized image pixels as input data, - // which should be in the range [-1.0, 1.0]. - SetupTensor<float>( - input, {1, 1, 28, 28}, static_cast<float>(-1), static_cast<float>(1)); - std::vector<paddle::framework::LoDTensor*> cpu_feeds; - cpu_feeds.push_back(&input); - - paddle::framework::LoDTensor output1; - std::vector<paddle::framework::LoDTensor*> cpu_fetchs1; - cpu_fetchs1.push_back(&output1); - - // Run inference on CPU - TestInference<paddle::platform::CPUPlace, true>( - dirname, cpu_feeds, cpu_fetchs1); - LOG(INFO) << output1.dims(); - -#ifdef PADDLE_WITH_CUDA - paddle::framework::LoDTensor output2; - std::vector<paddle::framework::LoDTensor*> cpu_fetchs2; - cpu_fetchs2.push_back(&output2); - - // Run inference on CUDA GPU - TestInference<paddle::platform::CUDAPlace, true>( - dirname, cpu_feeds, cpu_fetchs2); - LOG(INFO) << output2.dims(); - - CheckError<float>(output1, output2); -#endif } diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index 49518e50d8..d0688445fe 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include <time.h> #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/profiler.h" template <typename T> void SetupTensor(paddle::framework::LoDTensor& input, @@ -87,31 +88,58 @@ void CheckError(paddle::framework::LoDTensor& output1, EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; } -template <typename Place, bool IsCombined = false> +template <typename Place> void TestInference(const std::string& dirname, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, - std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) { + std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, + const int repeat = 1, + const bool is_combined = false) { // 1. Define place, executor, scope auto place = Place(); auto executor = paddle::framework::Executor(place); auto* scope = new paddle::framework::Scope(); + // Profile the performance + paddle::platform::ProfilerState state; + if (paddle::platform::is_cpu_place(place)) { + state = paddle::platform::ProfilerState::kCPU; + } else { +#ifdef PADDLE_WITH_CUDA + state = paddle::platform::ProfilerState::kCUDA; + // The default device_id of paddle::platform::CUDAPlace is 0. + // Users can get the device_id using: + // int device_id = place.GetDeviceId(); + paddle::platform::SetDeviceId(0); +#endif + } + + // Enable the profiler + paddle::platform::EnableProfiler(state); + // 2. Initialize the inference_program and load parameters std::unique_ptr<paddle::framework::ProgramDesc> inference_program; - if (IsCombined) { - // All parameters are saved in a single file. - // Hard-coding the file names of program and parameters in unittest. - // The file names should be consistent with that used in Python API - // `fluid.io.save_inference_model`. - std::string prog_filename = "__model_combined__"; - std::string param_filename = "__params_combined__"; - inference_program = paddle::inference::Load(executor, - *scope, - dirname + "/" + prog_filename, - dirname + "/" + param_filename); - } else { - // Parameters are saved in separate files sited in the specified `dirname`. - inference_program = paddle::inference::Load(executor, *scope, dirname); + { + paddle::platform::RecordEvent record_event( + "init_program", + paddle::platform::DeviceContextPool::Instance().Get(place)); + + if (is_combined) { + // All parameters are saved in a single file. + // Hard-coding the file names of program and parameters in unittest. + // The file names should be consistent with that used in Python API + // `fluid.io.save_inference_model`. + std::string prog_filename = "__model_combined__"; + std::string param_filename = "__params_combined__"; + inference_program = + paddle::inference::Load(executor, + *scope, + dirname + "/" + prog_filename, + dirname + "/" + param_filename); + } else { + // Parameters are saved in separate files sited in the specified + // `dirname`. + inference_program = paddle::inference::Load(executor, *scope, dirname); + } } // 3. Get the feed_target_names and fetch_target_names @@ -134,7 +162,21 @@ void TestInference(const std::string& dirname, } // 6. 
Run the inference program - executor.Run(*inference_program, scope, feed_targets, fetch_targets); + { + // Run repeat times to profile the performance + for (int i = 0; i < repeat; ++i) { + paddle::platform::RecordEvent record_event( + "run_inference", + paddle::platform::DeviceContextPool::Instance().Get(place)); + + executor.Run(*inference_program, scope, feed_targets, fetch_targets); + } + } + + // Disable the profiler and print the timing information + paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault, + "profiler.txt"); + paddle::platform::ResetProfiler(); delete scope; } diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 094f9224f7..28ef3e04b1 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -178,7 +178,7 @@ void EnableProfiler(ProfilerState state) { } #ifdef PADDLE_WITH_CUDA if (g_state == ProfilerState::kCUDA) { - // Generate some dummy evenets first to reduce the startup overhead. + // Generate some dummy events first to reduce the startup overhead. for (int i = 0; i < 5; i++) { ForEachDevice([](int d) { DeviceContext* dev_ctx = new CUDADeviceContext(CUDAPlace(d)); From eb4684531307f1f1278a1e0869c8ebaf8ced4dc0 Mon Sep 17 00:00:00 2001 From: Xin Pan <panxin.grad@gmail.com> Date: Wed, 7 Mar 2018 18:07:42 -0800 Subject: [PATCH 25/40] Add warning --- paddle/fluid/platform/device_tracer.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 6efe703e22..8e1691efb5 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -201,6 +201,7 @@ class DeviceTracerImpl : public DeviceTracer { uint32_t correlation_id, uint64_t bytes) { // 0 means timestamp information could not be collected for the kernel. if (start_ns == 0 || end_ns == 0) { + LOG(WARNING) << name << " cannot be traced"; return; } std::lock_guard<std::mutex> l(trace_mu_); @@ -212,6 +213,7 @@ class DeviceTracerImpl : public DeviceTracer { uint32_t stream_id, uint32_t correlation_id) { // 0 means timestamp information could not be collected for the kernel. 
if (start == 0 || end == 0) { + LOG(WARNING) << correlation_id << " cannot be traced"; return; } std::lock_guard<std::mutex> l(trace_mu_); From f7c71356732a02eb1dd0a3596c12270b28e5e89e Mon Sep 17 00:00:00 2001 From: chengduoZH <zhaochengduo@163.com> Date: Thu, 8 Mar 2018 10:35:32 +0800 Subject: [PATCH 26/40] Add log before op Run --- paddle/fluid/framework/executor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 961e3e22f2..f8e7d0d990 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -125,8 +125,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); - VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); + VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); op->Run(*local_scope, place_); + VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: " From 47ca1814f3e1d81f6a01b18d94f4267c54123e9b Mon Sep 17 00:00:00 2001 From: QI JUN <qijun1994@hotmail.com> Date: Thu, 8 Mar 2018 10:44:30 +0800 Subject: [PATCH 27/40] fix mac build error (#8856) --- paddle/fluid/inference/tests/test_helper.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h index d0688445fe..0f5fe6d0aa 100644 --- a/paddle/fluid/inference/tests/test_helper.h +++ b/paddle/fluid/inference/tests/test_helper.h @@ -110,6 +110,8 @@ void TestInference(const std::string& dirname, // Users can get the device_id using: // int device_id = place.GetDeviceId(); paddle::platform::SetDeviceId(0); +#else + PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); #endif } From 30e556d675fa958783af6a5e31fa78616dc20c77 Mon Sep 17 00:00:00 2001 From: Xin Pan <panxin.grad@gmail.com> Date: Wed, 7 Mar 2018 19:09:45 -0800 Subject: [PATCH 28/40] Use vlog instead. --- paddle/fluid/platform/device_tracer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 8e1691efb5..c4c0963091 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -201,7 +201,7 @@ class DeviceTracerImpl : public DeviceTracer { uint32_t correlation_id, uint64_t bytes) { // 0 means timestamp information could not be collected for the kernel. if (start_ns == 0 || end_ns == 0) { - LOG(WARNING) << name << " cannot be traced"; + VLOG(3) << name << " cannot be traced"; return; } std::lock_guard<std::mutex> l(trace_mu_); @@ -213,7 +213,7 @@ class DeviceTracerImpl : public DeviceTracer { uint32_t stream_id, uint32_t correlation_id) { // 0 means timestamp information could not be collected for the kernel. if (start == 0 || end == 0) { - LOG(WARNING) << correlation_id << " cannot be traced"; + VLOG(3) << correlation_id << " cannot be traced"; return; } std::lock_guard<std::mutex> l(trace_mu_); From fecc9a38c61ea707d28230de6f009d446fbac152 Mon Sep 17 00:00:00 2001 From: Yiqun Liu <liuyiqun01@baidu.com> Date: Thu, 8 Mar 2018 14:36:27 +0800 Subject: [PATCH 29/40] Add test for nested RecordEvent. (#8773) * Add test for nested RecordEvent. * Remove the debug information. * Add log information for the 3 usages and reduce the loop counts of nested case. 
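The nested usage exercised by the new test is plain RAII scoping of
events; a minimal sketch (assuming `dev_ctx` is a valid `DeviceContext*`
obtained as in the test, and the profiler is already enabled):

```cpp
{
  // The outer event covers the whole block.
  paddle::platform::RecordEvent outer_event("outer_op", dev_ctx);
  // ... work attributed to "outer_op" ...
  {
    // The inner event is recorded within the outer event's lifetime.
    paddle::platform::RecordEvent nested_event("nested_op", dev_ctx);
    // ... work attributed to both "outer_op" and "nested_op" ...
  }
}
```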
--- paddle/fluid/platform/profiler_test.cc | 30 ++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/paddle/fluid/platform/profiler_test.cc b/paddle/fluid/platform/profiler_test.cc index 4a86d8ec62..fc77e0f321 100644 --- a/paddle/fluid/platform/profiler_test.cc +++ b/paddle/fluid/platform/profiler_test.cc @@ -75,6 +75,7 @@ TEST(RecordEvent, RecordEvent) { * ... * PopEvent(evt_name, dev_ctx); */ + LOG(INFO) << "Usage 1: PushEvent & PopEvent"; for (int loop = 0; loop < 3; ++loop) { for (int i = 1; i < 5; ++i) { std::string name = "op_" + std::to_string(i); @@ -93,6 +94,7 @@ TEST(RecordEvent, RecordEvent) { * ... * } */ + LOG(INFO) << "Usage 2: RecordEvent"; for (int i = 1; i < 5; ++i) { std::string name = "evs_op_" + std::to_string(i); RecordEvent record_event(name, dev_ctx); @@ -100,6 +102,34 @@ TEST(RecordEvent, RecordEvent) { while (counter != i * 1000) counter++; } + /* Usage 3 + * { + * RecordEvent record_event(name1, dev_ctx); + * ... + * code to be analyzed + * ... + * { + * RecordEvent nested_record_event(name2, dev_ctx); + * ... + * code to be analyzed + * ... + * } + * } + */ + LOG(INFO) << "Usage 3: nested RecordEvent"; + for (int i = 1; i < 5; ++i) { + std::string name = "ano_evs_op_" + std::to_string(i); + RecordEvent record_event(name, dev_ctx); + int counter = 1; + while (counter != i * 100) counter++; + { + std::string nested_name = "nested_ano_evs_op_" + std::to_string(i); + RecordEvent nested_record_event(nested_name, dev_ctx); + int nested_counter = 1; + while (nested_counter != i * 100) nested_counter++; + } + } + // Bad Usage: PushEvent("event_without_pop", dev_ctx); PopEvent("event_without_push", dev_ctx); From ffda2c414dcf520d2dea0ac9fccd0478dc8e081f Mon Sep 17 00:00:00 2001 From: qingqing01 <dangqingqing@baidu.com> Date: Thu, 8 Mar 2018 16:27:26 +0800 Subject: [PATCH 30/40] Clipping bbox in the mAP evaluator calculation. 
(#8872) --- paddle/fluid/operators/detection_map_op.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detection_map_op.h b/paddle/fluid/operators/detection_map_op.h index 637f8368f8..a009e9dfce 100644 --- a/paddle/fluid/operators/detection_map_op.h +++ b/paddle/fluid/operators/detection_map_op.h @@ -144,6 +144,15 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { } } + inline void ClipBBox(const Box& bbox, Box* clipped_bbox) const { + T one = static_cast<T>(1.0); + T zero = static_cast<T>(0.0); + clipped_bbox->xmin = std::max(std::min(bbox.xmin, one), zero); + clipped_bbox->ymin = std::max(std::min(bbox.ymin, one), zero); + clipped_bbox->xmax = std::max(std::min(bbox.xmax, one), zero); + clipped_bbox->ymax = std::max(std::min(bbox.ymax, one), zero); + } + void GetBoxes(const framework::LoDTensor& input_label, const framework::LoDTensor& input_detect, std::vector<std::map<int, std::vector<Box>>>& gt_boxes, @@ -360,7 +369,9 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> { size_t max_idx = 0; auto score = pred_boxes[i].first; for (size_t j = 0; j < matched_bboxes.size(); ++j) { - T overlap = JaccardOverlap(pred_boxes[i].second, matched_bboxes[j]); + Box& pred_box = pred_boxes[i].second; + ClipBBox(pred_box, &pred_box); + T overlap = JaccardOverlap(pred_box, matched_bboxes[j]); if (overlap > max_overlap) { max_overlap = overlap; max_idx = j; From cc0b8053f3f729bc709b476a18b455126b1ab026 Mon Sep 17 00:00:00 2001 From: guosheng <guosheng@baidu.com> Date: Thu, 8 Mar 2018 19:35:35 +0800 Subject: [PATCH 31/40] Refine the guide of RNN in docs --- doc/v2/howto/rnn/index_cn.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/doc/v2/howto/rnn/index_cn.rst b/doc/v2/howto/rnn/index_cn.rst index bcc8c2f46e..6b630ccaa4 100644 --- a/doc/v2/howto/rnn/index_cn.rst +++ b/doc/v2/howto/rnn/index_cn.rst @@ -1,10 +1,34 @@ RNN模型 =========== +循环神经网络(RNN)是对序列数据建模的重要工具。PaddlePaddle提供了灵活的接口以支持复杂循环神经网络的构建。 +这一部分将分以下章节详细介绍如何使用PaddlePaddle搭建循环神经网络。 .. toctree:: :maxdepth: 1 rnn_config_cn.rst + +本章节由浅入深的展示了使用PaddlePaddle搭建循环神经网络的全貌:首先以简单的循环神经网络(vanilla RNN)为例, +说明如何封装配置循环神经网络组件;然后更进一步的通过sequence to sequence模型,逐步讲解如何构建完整而复杂的循环神经网络模型。 + +.. toctree:: + :maxdepth: 1 + recurrent_group_cn.md + +Recurrent Group是PaddlePaddle中实现复杂循环神经网络的关键,本章节阐述了PaddlePaddle中Recurrent Group的相关概念和原理, +对Recurrent Group接口进行了详细说明。另外,对双层RNN(对应的输入为双层序列)及Recurrent Group在其中的使用进行了介绍。 + +.. toctree:: + :maxdepth: 1 + hierarchical_layer_cn.rst + +本章节对双层序列进行了解释说明,列出了PaddlePaddle中支持双层序列作为输入的Layer并对其使用进行了逐一介绍。 + +.. 
toctree:: + :maxdepth: 1 + hrnn_rnn_api_compare_cn.rst + +本章节以PaddlePaddle的双层RNN单元测试中的网络配置为示例,辅以效果相同的单层RNN网络配置作为对比,讲解了多种情况下双层RNN的使用。 From 65cfd32774c86e1377ea86a26a40b35a137665b8 Mon Sep 17 00:00:00 2001 From: Melobelle <564445201@qq.com> Date: Thu, 8 Mar 2018 20:17:43 +0800 Subject: [PATCH 32/40] Fix some path errors (#8851) --- doc/fluid/read_source.md | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/doc/fluid/read_source.md b/doc/fluid/read_source.md index edf46aff8c..bb6d4563f5 100644 --- a/doc/fluid/read_source.md +++ b/doc/fluid/read_source.md @@ -2,17 +2,17 @@ Examples: https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/fluid/tests/book -Core: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/framework +Core: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/framework -Operator: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators +Operator: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/operators -Memory: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/memory +Memory: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/memory -Platform: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/platform +Platform: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/platform # Compile Time -The following **defines** the NN. The definition goes into this [protocol buffer](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto). +The following **defines** the NN. The definition goes into this [protocol buffer](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/framework.proto). ```python x = fluid.layers.data(name='x', shape=[13], dtype='float32') @@ -29,10 +29,10 @@ sgd_optimizer.minimize(avg_cost) - Variables: `x`, `y`, `y_predict`, `cost` and `avg_cost`. [Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/framework.py#) - Layers: `fluid.layers.data`, `fluid.layers.fc` and `fluid.layers.mean` are layers. [Python](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/fluid/layers) - Every Layer has one or more operators and variables/parameters - - All the operators are defined at [`paddle/operators/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators). Other worth-looking files: - - Base class: [`paddle/framework/operator.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h) - - Operator Registration: [`paddle/framework/op_registry.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h) - - Operator Lookup: [`paddle/framework/op_info.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_info.h) + - All the operators are defined at [`paddle/fluid/operators/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/operators). Other worth-looking files: + - Base class: [`paddle/fluid/framework/operator.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/operator.h) + - Operator Registration: [`paddle/fluid/framework/op_registry.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/op_registry.h) + - Operator Lookup: [`paddle/fluid/framework/op_info.h`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/op_info.h) - Optimizer: `fluid.optimizer.SGD`. It does the following - Add backward operators. 
[[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/backward.py)] - Add optimizer operators. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/optimizer.py)] @@ -55,13 +55,13 @@ exe.run(fluid.default_main_program(), fetch_list=[avg_cost]) ``` -- Place: `place`. one of CPU, GPU or FPGA. [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h) - - The device handle are at [paddle/platform/device_context.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h) -- Executor: `fluid.Executor(place)`. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/executor.py), [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/executor.cc)] +- Place: `place`. one of CPU, GPU or FPGA. [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/platform/place.h) + - The device handle are at [paddle/fluid/platform/device_context.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/platform/device_context.h) +- Executor: `fluid.Executor(place)`. [[Python](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/fluid/executor.py), [C++](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/executor.cc)] - Feeds the data: `feed=feeder.feed(data)` - Evaluates all the operators - Fetches the result: `fetch_list=[avg_cost]` - Other worth looking files: - - Scope: [paddle/framework/scope.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/scope.h). Where all the variables live - - Variable: [paddle/framework/variable.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h). Where all the data (most likely tensors) live - - Tensor: [paddle/framework/tensor.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h). Where we allocate memory through [`paddle/memory/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/memory) + - Scope: [paddle/fluid/framework/scope.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/scope.h). Where all the variables live + - Variable: [paddle/fluid/framework/variable.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/variable.h). Where all the data (most likely tensors) live + - Tensor: [paddle/fluid/framework/tensor.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/tensor.h). Where we allocate memory through [`paddle/fluid/memory/`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/memory) From 71dd899369081bdb46f58d7501154cd5ea980ce0 Mon Sep 17 00:00:00 2001 From: guosheng <guosheng@baidu.com> Date: Thu, 8 Mar 2018 21:11:37 +0800 Subject: [PATCH 33/40] Refine the guide of RNN in docs by following comments --- doc/v2/howto/rnn/index_cn.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/v2/howto/rnn/index_cn.rst b/doc/v2/howto/rnn/index_cn.rst index 6b630ccaa4..2032fb9e29 100644 --- a/doc/v2/howto/rnn/index_cn.rst +++ b/doc/v2/howto/rnn/index_cn.rst @@ -1,34 +1,34 @@ RNN模型 =========== 循环神经网络(RNN)是对序列数据建模的重要工具。PaddlePaddle提供了灵活的接口以支持复杂循环神经网络的构建。 -这一部分将分以下章节详细介绍如何使用PaddlePaddle搭建循环神经网络。 +这里将分为以下四个部分详细介绍如何使用PaddlePaddle搭建循环神经网络。 + +第一部分由浅入深的展示了使用PaddlePaddle搭建循环神经网络的全貌:首先以简单的循环神经网络(vanilla RNN)为例, +说明如何封装配置循环神经网络组件;然后更进一步的通过序列到序列(sequence to sequence)模型,逐步讲解如何构建完整而复杂的循环神经网络模型。 .. 
toctree:: :maxdepth: 1 rnn_config_cn.rst -本章节由浅入深的展示了使用PaddlePaddle搭建循环神经网络的全貌:首先以简单的循环神经网络(vanilla RNN)为例, -说明如何封装配置循环神经网络组件;然后更进一步的通过sequence to sequence模型,逐步讲解如何构建完整而复杂的循环神经网络模型。 +Recurrent Group是PaddlePaddle中实现复杂循环神经网络的关键,第二部分阐述了PaddlePaddle中Recurrent Group的相关概念和原理, +对Recurrent Group接口进行了详细说明。另外,对双层RNN(对应的输入为双层序列)及Recurrent Group在其中的使用进行了介绍。 .. toctree:: :maxdepth: 1 recurrent_group_cn.md -Recurrent Group是PaddlePaddle中实现复杂循环神经网络的关键,本章节阐述了PaddlePaddle中Recurrent Group的相关概念和原理, -对Recurrent Group接口进行了详细说明。另外,对双层RNN(对应的输入为双层序列)及Recurrent Group在其中的使用进行了介绍。 +第三部分对双层序列进行了解释说明,列出了PaddlePaddle中支持双层序列作为输入的Layer,并对其使用进行了逐一介绍。 .. toctree:: :maxdepth: 1 hierarchical_layer_cn.rst -本章节对双层序列进行了解释说明,列出了PaddlePaddle中支持双层序列作为输入的Layer并对其使用进行了逐一介绍。 +第四部分以PaddlePaddle的双层RNN单元测试中的网络配置为示例,辅以效果相同的单层RNN网络配置作为对比,讲解了多种情况下双层RNN的使用。 .. toctree:: :maxdepth: 1 hrnn_rnn_api_compare_cn.rst - -本章节以PaddlePaddle的双层RNN单元测试中的网络配置为示例,辅以效果相同的单层RNN网络配置作为对比,讲解了多种情况下双层RNN的使用。 From 9416703d8e12e2ebdefd08d56eb1dc1a7eb27986 Mon Sep 17 00:00:00 2001 From: Luo Tao <luotao02@baidu.com> Date: Fri, 9 Mar 2018 10:55:57 +0800 Subject: [PATCH 34/40] fix document deployment --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 9dd5f48164..bf6a41d13c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,7 +56,7 @@ script: export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/master/scripts/deploy/deploy_docs.sh export DOCS_DIR=`pwd` cd .. - curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH $DOCS_DIR $DOCS_DIR/build/doc/v2 + curl $DEPLOY_DOCS_SH | bash -s $CONTENT_DEC_PASSWD $TRAVIS_BRANCH $DOCS_DIR $DOCS_DIR/build/doc/ notifications: email: on_success: change From 9a27d3af233ec5d34382f2fee599fa55088c4688 Mon Sep 17 00:00:00 2001 From: Xin Pan <panxin.grad@gmail.com> Date: Thu, 8 Mar 2018 19:14:35 -0800 Subject: [PATCH 35/40] Print exception message from threads --- paddle/fluid/framework/threadpool.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index 3adc260caf..df51fb24a5 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -67,10 +67,10 @@ class ThreadPool { } catch (platform::EnforceNotMet ex) { return std::unique_ptr<platform::EnforceNotMet>( new platform::EnforceNotMet(ex)); - } catch (...) { - LOG(FATAL) - << "Unexpected exception is catched in thread pool. All " - "throwable exception in Fluid should be an EnforceNotMet."; + } catch (const std::exception& e) { + LOG(FATAL) << "Unexpected exception is catched in thread pool. All " + "throwable exception in Fluid should be an EnforceNotMet." 
+ << e.what(); } return nullptr; }); From 45af8c1e99333d807c052277220b0fd01b2bd18a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= <typhoonzero1986@gmail.com> Date: Fri, 9 Mar 2018 11:17:10 +0800 Subject: [PATCH 36/40] Performance/zero copy variable seriralization (#8839) --- paddle/fluid/framework/tensor_util.cc | 1 - paddle/fluid/operators/detail/CMakeLists.txt | 5 +- .../operators/detail/bytebuffer_stream.cc | 88 +++++++ .../operators/detail/bytebuffer_stream.h | 51 ++++ .../operators/detail/proto_encoder_helper.h | 147 +++++++++++ paddle/fluid/operators/detail/send_recv.proto | 26 +- .../operators/detail/sendrecvop_utils.cc | 243 +++++++++++++++++- .../fluid/operators/detail/sendrecvop_utils.h | 34 +++ paddle/fluid/operators/detail/test_serde.cc | 195 ++++++++++++++ 9 files changed, 786 insertions(+), 4 deletions(-) create mode 100644 paddle/fluid/operators/detail/bytebuffer_stream.cc create mode 100644 paddle/fluid/operators/detail/bytebuffer_stream.h create mode 100644 paddle/fluid/operators/detail/proto_encoder_helper.h create mode 100644 paddle/fluid/operators/detail/test_serde.cc diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 9b465b85b0..8b7533ce71 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -187,7 +187,6 @@ bool TensorContainsInf(const framework::Tensor& tensor) { void TensorToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { - // TODO(typhoonzero): serialize to ostream { // the 1st field, uint32_t version constexpr uint32_t version = 0; os.write(reinterpret_cast<const char*>(&version), sizeof(version)); diff --git a/paddle/fluid/operators/detail/CMakeLists.txt b/paddle/fluid/operators/detail/CMakeLists.txt index 0581bd2ac5..94395ccfbc 100644 --- a/paddle/fluid/operators/detail/CMakeLists.txt +++ b/paddle/fluid/operators/detail/CMakeLists.txt @@ -1,3 +1,6 @@ if(WITH_DISTRIBUTE) - grpc_library(sendrecvop_grpc SRCS sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_server.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) + set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + set_source_files_properties(test_serde.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + cc_test(serde_test SRCS test_serde.cc DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc) endif() diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.cc b/paddle/fluid/operators/detail/bytebuffer_stream.cc new file mode 100644 index 0000000000..a9488156e0 --- /dev/null +++ b/paddle/fluid/operators/detail/bytebuffer_stream.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#include "bytebuffer_stream.h" + +namespace paddle { +namespace operators { +namespace detail { + +GrpcByteBufferSource::GrpcByteBufferSource() {} + +bool GrpcByteBufferSource::Init(const grpc::ByteBuffer& src) { + cur_ = -1; + left_ = 0; + ptr_ = nullptr; + byte_count_ = 0; + bool ok = src.Dump(&slices_).ok(); + if (!ok) { + slices_.clear(); + } + return ok; +} + +bool GrpcByteBufferSource::Next(const void** data, int* size) { + // Use loop instead of if in case buffer contained empty slices. + while (left_ == 0) { + // Advance to next slice. + cur_++; + if (cur_ >= slices_.size()) { + return false; + } + const ::grpc::Slice& s = slices_[cur_]; + left_ = s.size(); + ptr_ = reinterpret_cast<const char*>(s.begin()); + } + + *data = ptr_; + *size = left_; + byte_count_ += left_; + ptr_ += left_; + left_ = 0; + return true; +} + +void GrpcByteBufferSource::BackUp(int count) { + ptr_ -= count; + left_ += count; + byte_count_ -= count; +} + +bool GrpcByteBufferSource::Skip(int count) { + const void* data; + int size; + while (Next(&data, &size)) { + if (size >= count) { + BackUp(size - count); + return true; + } + // size < count; + count -= size; + } + // error or we have too large count; + return false; +} + +google::protobuf::int64 GrpcByteBufferSource::ByteCount() const { + return byte_count_; +} + +} // namespace detail +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/bytebuffer_stream.h b/paddle/fluid/operators/detail/bytebuffer_stream.h new file mode 100644 index 0000000000..099deb12d0 --- /dev/null +++ b/paddle/fluid/operators/detail/bytebuffer_stream.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#pragma once + +#include <grpc++/grpc++.h> +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" + +namespace paddle { +namespace operators { +namespace detail { + +// A ZeroCopyInputStream that reads from a grpc::ByteBuffer. +class GrpcByteBufferSource + : public ::google::protobuf::io::ZeroCopyInputStream { + public: + GrpcByteBufferSource(); + bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times. + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + ::google::protobuf::int64 ByteCount() const override; + + private: + std::vector<::grpc::Slice> slices_; + size_t cur_; // Current slice index. 
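In use, this stream is meant to be wrapped in a `CodedInputStream` so protobuf can parse directly out of the gRPC slices. A usage sketch, mirroring the `DeserializeFromByteBuffer` path later in this patch (the helper function name and the generated `send_recv.pb.h` include are assumptions for illustration):

```cpp
#include <grpc++/grpc++.h>
#include "google/protobuf/io/coded_stream.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"

// Zero-copy parse of a VariableMessage straight from grpc slices;
// no contiguous copy of the serialized message is made.
bool ParseVariableMessage(const ::grpc::ByteBuffer& msg,
                          sendrecv::VariableMessage* meta) {
  paddle::operators::detail::GrpcByteBufferSource source;
  if (!source.Init(msg)) return false;
  ::google::protobuf::io::CodedInputStream input(&source);
  return meta->ParseFromCodedStream(&input) && input.ConsumedEntireMessage();
}
```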
+ int left_; // Number of bytes in slices_[cur_] left to yield. + const char* ptr_; // Address of next byte in slices_[cur_] to yield. + ::google::protobuf::int64 byte_count_; +}; + +} // namespace detail +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/detail/proto_encoder_helper.h b/paddle/fluid/operators/detail/proto_encoder_helper.h new file mode 100644 index 0000000000..4a7bfb8bd5 --- /dev/null +++ b/paddle/fluid/operators/detail/proto_encoder_helper.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +// NOTE: This file was originally created by tensorflow +// (https://github.com/tensorflow/tensorflow/) we borrow this +// file and did some modifications so that we can send gRPC +// requests without too much copying of the tensor data. + +#pragma once + +#include <grpc++/grpc++.h> +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace detail { + +char* EncodeVarint32(char* dst, uint32_t v) { + // Operate on characters as unsigneds + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + static const int B = 128; + if (v < (1 << 7)) { + *(ptr++) = v; + } else if (v < (1 << 14)) { + *(ptr++) = v | B; + *(ptr++) = v >> 7; + } else if (v < (1 << 21)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = v >> 14; + } else if (v < (1 << 28)) { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = v >> 21; + } else { + *(ptr++) = v | B; + *(ptr++) = (v >> 7) | B; + *(ptr++) = (v >> 14) | B; + *(ptr++) = (v >> 21) | B; + *(ptr++) = v >> 28; + } + return reinterpret_cast<char*>(ptr); +} + +char* EncodeVarint64(char* dst, uint64_t v) { + static const int B = 128; + unsigned char* ptr = reinterpret_cast<unsigned char*>(dst); + while (v >= B) { + *(ptr++) = (v & (B - 1)) | B; + v >>= 7; + } + *(ptr++) = static_cast<unsigned char>(v); + return reinterpret_cast<char*>(ptr); +} + +int VarintLength(uint64_t v) { + int len = 1; + while (v >= 128) { + v >>= 7; + len++; + } + return len; +} + +class ProtoEncodeHelper { + public: + ProtoEncodeHelper(char* buf, int max_size) + : base_(buf), p_(buf), limit_(base_ + max_size) {} + + ~ProtoEncodeHelper() { + // Make sure callers didn't do operations that went over max_size promised + PADDLE_ENFORCE_LE(p_, limit_); + } + + const char* data() const { return base_; } + size_t size() const { return p_ - base_; } + + void WriteUint64(int tag, uint64_t v) { + Encode32(combine(tag, WIRETYPE_VARINT)); + Encode64(v); + } + void WriteBool(int tag, bool v) { + Encode32(combine(tag, WIRETYPE_VARINT)); + EncodeBool(v); + } + void WriteString(int tag, const std::string& v) { + Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); + Encode32(v.size()); + EncodeBytes(v.data(), v.size()); + } + void WriteVarlengthBeginning(int tag, uint32_t len) { + Encode32(combine(tag, WIRETYPE_LENGTH_DELIMITED)); + Encode32(len); + } + void WriteRawBytes(const std::string& v) { EncodeBytes(v.data(), 
v.size()); } + + private: + // Note: this module's behavior must match the protocol buffer wire encoding + // format. + enum { + WIRETYPE_VARINT = 0, + WIRETYPE_LENGTH_DELIMITED = 2, + }; + static uint32_t combine(uint32_t tag, uint32_t type) { + return ((tag << 3) | type); + } + inline void Encode32(uint32_t v) { + if (v < 128) { + // Fast path for single-byte values. Many of the calls will use a + // constant value for v, so the comparison will get optimized away + // when Encode32 is inlined into the caller. + *p_ = v; + p_++; + } else { + p_ = EncodeVarint32(p_, v); + } + } + void Encode64(uint64_t v) { p_ = EncodeVarint64(p_, v); } + void EncodeBool(bool v) { + *p_ = (v ? 1 : 0); // Equal to varint32 encoding of 0 or 1 + p_++; + } + void EncodeBytes(const char* bytes, int N) { + memcpy(p_, bytes, N); + p_ += N; + } + + char* base_; + char* p_; + char* limit_; // Just for CHECKs +}; + +} // detail +} // operators +} // paddle diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto index 8f962b4c69..b0215d4a80 100644 --- a/paddle/fluid/operators/detail/send_recv.proto +++ b/paddle/fluid/operators/detail/send_recv.proto @@ -33,10 +33,34 @@ enum VarType { } message VariableMessage { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + } + + message LodData { repeated int64 lod_data = 1; } + string varname = 1; // TODO(Yancey1989): reference framework::proto::VarDesc::VarType VarType type = 2; - bytes serialized = 3; + // bool persistable is not needed for sending. + // tensor info: + Type data_type = 3; + repeated int64 dims = 4; + + // lod details: + int64 lod_level = 5; + repeated LodData lod = 6; + // tensor data + bytes serialized = 7; + // selected_rows data + bytes rows = 8; } message VoidMessage {} diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index 169fd40fd9..64d181f408 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream.h" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/operators/detail/bytebuffer_stream.h" +#include "paddle/fluid/operators/detail/proto_encoder_helper.h" namespace paddle { namespace operators { @@ -63,6 +68,242 @@ void DeserializeFromMessage(const sendrecv::VariableMessage& msg, } } +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg) { + using VarMsg = sendrecv::VariableMessage; + sendrecv::VariableMessage request; + std::string header; + request.AppendToString(&header); + // When using GPU, need to free the copied CPU buffer + // when the ByteBuffer destroies + // TODO(typhoonzero): add unref here, if we have dependent + // parallelism execution, need to know when to free the tensor. 
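Before moving on: the hand-rolled varint writers above must stay byte-compatible with protobuf's wire format, because the buffer they produce is later parsed by protobuf itself. A small self-check sketch (illustrative only; `VarintRoundTrip` is not part of this patch):

```cpp
#include <cassert>
#include <cstdint>
#include "google/protobuf/io/coded_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"

// Bytes produced by EncodeVarint32 must parse back through protobuf's
// own varint reader -- that is the compatibility ProtoEncodeHelper
// relies on when it hand-writes message fields.
void VarintRoundTrip() {
  char buf[8];
  uint32_t value = 300;  // encodes as two bytes: 0xAC 0x02
  char* end = paddle::operators::detail::EncodeVarint32(buf, value);
  assert(end - buf == paddle::operators::detail::VarintLength(value));

  google::protobuf::io::CodedInputStream in(
      reinterpret_cast<const uint8_t*>(buf), static_cast<int>(end - buf));
  uint32_t parsed = 0;
  in.ReadVarint32(&parsed);
  assert(parsed == value);
}
```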
+ DestroyCallback destroy_callback = [](void* backing) {}; + + void* buf = malloc(1024); + void* payload; + size_t payload_size; + ProtoEncodeHelper e((char*)buf, 1024); + e.WriteString(VarMsg::kVarnameFieldNumber, name); + if (var->IsType<framework::LoDTensor>()) { + e.WriteUint64(VarMsg::kTypeFieldNumber, 0); + } else if (var->IsType<framework::SelectedRows>()) { + e.WriteUint64(VarMsg::kTypeFieldNumber, 1); + } + + switch (framework::ToVarType(var->Type())) { + case framework::proto::VarType_Type_LOD_TENSOR: { + auto tensor = var->Get<framework::LoDTensor>(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(tensor.type())); + for (auto& dim : framework::vectorize(tensor.dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + auto lod = tensor.lod(); // std::vector<Vector<size_t>> + if (lod.size() > 0) { + e.WriteUint64(VarMsg::kLodLevelFieldNumber, lod.size()); + + for (auto& each : lod) { + e.WriteVarlengthBeginning(VarMsg::kLodFieldNumber, + 2 + // tag + varintlength of submessage + 1 + // kLodDataFieldNumber + each.size()); + // auto copied from GPU + for (auto& d : each) { + e.WriteUint64(VarMsg::LodData::kLodDataFieldNumber, d); + } + } + } + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(platform::is_gpu_place(tensor.place())); + platform::CPUPlace cpu; + auto& gpu_dev_ctx = + static_cast<const platform::CUDADeviceContext&>(ctx); + auto copy_size = tensor.memory_size(); + payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, + boost::get<platform::CUDAPlace>(tensor.place()), + reinterpret_cast<const void*>(tensor.data<void>()), + copy_size, gpu_dev_ctx.stream()); + destroy_callback = [](void* backing) { + std::cout << "destroy payload" << std::endl; + platform::CPUPlace cpu; + memory::Free(cpu, backing); + }; +#endif + } else { + payload = tensor.data<void>(); + } + payload_size = tensor.memory_size(); + + std::string tmp(reinterpret_cast<char*>(payload), payload_size); + for (int i = 0; i < tmp.size(); ++i) { + printf("%02X ", tmp.data()[i]); + } + printf("\n"); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } break; + case framework::proto::VarType_Type_SELECTED_ROWS: { + // TODO(typhoonzero): selectedrows implement should not use unique_ptr + auto* slr = var->GetMutable<framework::SelectedRows>(); + e.WriteUint64(VarMsg::kDataTypeFieldNumber, + framework::ToDataType(slr->value().type())); + for (auto& dim : framework::vectorize(slr->value().dims())) { + e.WriteUint64(VarMsg::kDimsFieldNumber, dim); + } + e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0); + auto* tensor = slr->mutable_value(); + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = + static_cast<const platform::CUDADeviceContext&>(ctx); + auto copy_size = tensor->memory_size(); + payload = memory::Alloc(cpu, copy_size); + memory::Copy(cpu, payload, + boost::get<platform::CUDAPlace>(tensor->place()), + reinterpret_cast<const void*>(tensor->data<void>()), + copy_size, gpu_dev_ctx.stream()); + ctx.Wait(); + float* ttt = reinterpret_cast<float*>(payload); + for (int i = 0; i < copy_size / 4; i++) { + std::cout << "copied to cpu: " << ttt[i] << std::endl; + } + destroy_callback = [](void* backing) { + std::cout << "destroy..." 
<< std::endl; + // platform::CPUPlace cpu; + // memory::Free(cpu, backing); + }; +#endif + } else { + payload = slr->mutable_value()->data<void>(); + } + payload_size = tensor->memory_size(); + e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); + } break; + default: + PADDLE_THROW("Serialize does not support type: %s", + typeid(var->Type()).name()); + break; + } + // steal reference of tensor data + ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows + int num_slices = 2; // only SelectedRows have rows buffer + slices[0] = ::grpc::Slice(e.size()); + memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size()); + slices[1] = ::grpc::Slice( + grpc_slice_new_with_user_data(payload, payload_size, destroy_callback, + static_cast<char*>(payload)), + ::grpc::Slice::STEAL_REF); + + if (framework::ToVarType(var->Type()) == + framework::proto::VarType_Type_SELECTED_ROWS) { + auto* slr = var->GetMutable<framework::SelectedRows>(); + + ProtoEncodeHelper e2((char*)buf, 128); + // NOTE: rows is of type int64_t + size_t rows_memory_size = + slr->rows().capacity() * framework::SizeOfType(typeid(int64_t)); + e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); + slices[2] = ::grpc::Slice(e2.size()); + memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); + + slices[3] = ::grpc::Slice( + grpc_slice_new_with_user_data( + const_cast<void*>( + reinterpret_cast<const void*>(slr->rows().data())), + rows_memory_size, + [](void* backing) { + // TODO(typhoonzero): add unref here, same as above. + }, + const_cast<char*>( + reinterpret_cast<const char*>(slr->rows().data()))), + ::grpc::Slice::STEAL_REF); + num_slices = 4; + } + + ::grpc::ByteBuffer tmp(&slices[0], num_slices); + msg->Swap(&tmp); +} + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + framework::Variable* var) { + sendrecv::VariableMessage meta; + GrpcByteBufferSource source; + source.Init(msg); + ::google::protobuf::io::CodedInputStream input(&source); + // do zerocopy parsing + PADDLE_ENFORCE(meta.ParseFromCodedStream(&input)); + PADDLE_ENFORCE(input.ConsumedEntireMessage()); + // dims is needed by both tensor and selectedrows + std::vector<int> vecdims; + for (auto& d : meta.dims()) { + vecdims.push_back(d); + } + framework::DDim dims = framework::make_ddim(vecdims); + + if (meta.type() == sendrecv::LOD_TENSOR) { + auto* tensor = var->GetMutable<framework::LoDTensor>(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta.data_type())); + framework::LoD lod; + for (int i = 0; i < meta.lod_level(); ++i) { + framework::Vector<size_t> v; + for (int j = 0; j < meta.lod(i).lod_data_size(); ++j) { + v.push_back(meta.lod(i).lod_data(j)); + } + lod.push_back(v); + } + tensor->set_lod(lod); + // How to avoid copying and use the message buffer directly? + // Maybe need to find a way to release all memory except tensor content. 
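For orientation, the serialize/deserialize pair is driven like this; the sketch is condensed from the serde unit test added later in this patch (`RoundTripSketch` itself is illustrative):

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"

// Serialize a LoDTensor variable into a grpc::ByteBuffer, then rebuild
// it into a fresh Variable. `ctx` is a DeviceContext for the place the
// tensor lives on (CPU or GPU).
void RoundTripSketch(const paddle::platform::DeviceContext& ctx) {
  paddle::framework::Variable src, dst;
  auto* t = src.GetMutable<paddle::framework::LoDTensor>();
  t->Resize(paddle::framework::make_ddim({4, 8}));
  t->mutable_data<float>(ctx.GetPlace());

  ::grpc::ByteBuffer buf;
  paddle::operators::detail::SerializeToByteBuffer("myvar", &src, ctx, &buf);
  paddle::operators::detail::DeserializeFromByteBuffer(buf, ctx, &dst);
}
```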
+ if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); + memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()), + tensor_data, cpu, + reinterpret_cast<const void*>(meta.serialized().data()), + meta.serialized().size(), gpu_dev_ctx.stream()); +#endif + } else { + memcpy(tensor_data, + reinterpret_cast<const void*>(meta.serialized().data()), + meta.serialized().size()); + } + } else if (meta.type() == sendrecv::SELECTED_ROWS) { + auto* slr = var->GetMutable<framework::SelectedRows>(); + auto* tensor = slr->mutable_value(); + int64_t* rows_data = slr->mutable_rows()->data(); + tensor->Resize(dims); + void* tensor_data = tensor->mutable_data( + ctx.GetPlace(), + paddle::operators::detail::ToTypeIndex(meta.data_type())); + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef PADDLE_WITH_CUDA + platform::CPUPlace cpu; + auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx); + memory::Copy(boost::get<platform::CUDAPlace>(tensor->place()), + tensor_data, cpu, + reinterpret_cast<const void*>(meta.serialized().data()), + meta.serialized().size(), gpu_dev_ctx.stream()); +#endif + } else { + memcpy(tensor_data, + reinterpret_cast<const void*>(meta.serialized().data()), + meta.serialized().size()); + } + // copy rows CPU data, GPU data will be copied lazly + memcpy(rows_data, reinterpret_cast<const void*>(meta.rows().data()), + meta.rows().size()); + } +} + } // namespace detail } // namespace operators -} // namespace paddle +} // namespace paddle \ No newline at end of file diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.h b/paddle/fluid/operators/detail/sendrecvop_utils.h index 670d0e1624..65704db5ae 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.h +++ b/paddle/fluid/operators/detail/sendrecvop_utils.h @@ -33,6 +33,14 @@ namespace detail { #define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" #define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" +typedef void (*DestroyCallback)(void*); + +inline int64_t GetTimestamp() { + return std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::system_clock::now().time_since_epoch()) + .count(); +} + void SerializeToMessage(const std::string& name, const framework::Variable* var, const platform::DeviceContext& ctx, sendrecv::VariableMessage* msg); @@ -40,6 +48,32 @@ void SerializeToMessage(const std::string& name, const framework::Variable* var, void DeserializeFromMessage(const sendrecv::VariableMessage& msg, const platform::DeviceContext& ctx, framework::Variable* var); + +void SerializeToByteBuffer(const std::string& name, framework::Variable* var, + const platform::DeviceContext& ctx, + ::grpc::ByteBuffer* msg); + +void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, + const platform::DeviceContext& ctx, + framework::Variable* var); + +inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { + switch (type) { + case sendrecv::VariableMessage::FP32: + return typeid(float); // NOLINT + case sendrecv::VariableMessage::FP64: + return typeid(double); // NOLINT + case sendrecv::VariableMessage::INT32: + return typeid(int); // NOLINT + case sendrecv::VariableMessage::INT64: + return typeid(int64_t); // NOLINT + case sendrecv::VariableMessage::BOOL: + return typeid(bool); // NOLINT + default: + PADDLE_THROW("Not support type %d", type); + } +} + } // namespace detail } // namespace operators } // namespace paddle diff --git 
a/paddle/fluid/operators/detail/test_serde.cc b/paddle/fluid/operators/detail/test_serde.cc new file mode 100644 index 0000000000..8054c89ecf --- /dev/null +++ b/paddle/fluid/operators/detail/test_serde.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include <unistd.h> +#include <string> +#include <thread> + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/detail/sendrecvop_utils.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace operators = paddle::operators; +namespace math = paddle::operators::math; +namespace memory = paddle::memory; + +void RunSerdeTestTensor(platform::Place place) { + // serialize var to ByteBuffer + framework::Variable var; + auto* tensor = var.GetMutable<framework::LoDTensor>(); + tensor->Resize(framework::make_ddim({4, 8, 4, 2})); + framework::LoD lod; + lod.push_back(framework::Vector<size_t>({1, 3, 8})); + tensor->set_lod(lod); + int tensor_numel = 4 * 8 * 4 * 2; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + float* orig_tensor_data = tensor->mutable_data<float>(place); + math::set_constant(ctx, tensor, 31.9); + + ::grpc::ByteBuffer msg; + operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + EXPECT_GT(msg.Length(), 0); + + // deserialize + std::vector<::grpc::Slice> slices; + (void)msg.Dump(&slices); + std::string tmp; + for (const auto& s : slices) { + tmp.append(reinterpret_cast<const char*>(s.begin()), s.size()); + } + sendrecv::VariableMessage varmsg; + EXPECT_TRUE(varmsg.ParseFromString(tmp)); + EXPECT_EQ(varmsg.varname(), "myvar"); + EXPECT_EQ(varmsg.type(), 0); + EXPECT_EQ(varmsg.dims()[0], 4); + EXPECT_EQ(varmsg.dims()[1], 8); + EXPECT_EQ(varmsg.dims()[2], 4); + EXPECT_EQ(varmsg.dims()[3], 2); + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + + const float* tensor_data = + reinterpret_cast<const float*>(varmsg.serialized().data()); + for (int i = 0; i < varmsg.serialized().size(); ++i) { + printf("%02X ", varmsg.serialized().data()[i]); + } + printf("\n"); + for (int i = 0; i < tensor_numel; ++i) { + std::cout << "#####tensor data: " << tensor_data[i] << std::endl; + EXPECT_EQ(tensor_data[i], orig_tensor_data[i]); + std::cout << "test end 1 " << std::endl; + } + std::cout << "tensor data end " << std::endl; + + // deserialize zero-copy + framework::Variable var2; + operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + auto tensor2 = var2.Get<framework::LoDTensor>(); + float* tensor_data2 = nullptr; + framework::Tensor tmp_tensor; + + if 
(platform::is_gpu_place(ctx.GetPlace())) { + platform::CPUPlace cpu; + framework::TensorCopy(tensor2, cpu, &tmp_tensor); + tensor_data2 = tmp_tensor.data<float>(); + } else { + tensor_data2 = const_cast<float*>(tensor2.data<float>()); + } + + EXPECT_EQ(varmsg.lod_level(), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(0), 1); + EXPECT_EQ(varmsg.lod(0).lod_data(1), 3); + EXPECT_EQ(varmsg.lod(0).lod_data(2), 8); + for (int i = 0; i < tensor_numel; ++i) + EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]); +} + +void RunSerdeTestSelectedRows(platform::Place place) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + // serialize var to ByteBuffer + framework::Variable var; + auto* slr = var.GetMutable<framework::SelectedRows>(); + auto* tensor = slr->mutable_value(); + auto* rows = slr->mutable_rows(); + + tensor->Resize(framework::make_ddim({2, 10})); + int tensor_numel = 2 * 10; + float* orig_tensor_data = tensor->mutable_data<float>(place); + math::set_constant(ctx, tensor, 32.7); + rows->push_back(3); + rows->push_back(10); + + ::grpc::ByteBuffer msg; + operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg); + EXPECT_GT(msg.Length(), 0); + + // deserialize + std::vector<::grpc::Slice> slices; + (void)msg.Dump(&slices); + std::string tmp; + for (const auto& s : slices) { + tmp.append(reinterpret_cast<const char*>(s.begin()), s.size()); + } + sendrecv::VariableMessage varmsg; + EXPECT_TRUE(varmsg.ParseFromString(tmp)); + + EXPECT_EQ(varmsg.varname(), "myvar"); + EXPECT_EQ(varmsg.type(), 1); + + const float* tensor_data = + reinterpret_cast<const float*>(varmsg.serialized().data()); + const int64_t* rows_data = + reinterpret_cast<const int64_t*>(varmsg.rows().data()); + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_EQ(tensor_data[i], orig_tensor_data[i]); + } + EXPECT_EQ(rows_data[0], 3); + EXPECT_EQ(rows_data[1], 10); + // deserialize zero-copy + framework::Variable var2; + operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2); + + auto* slr2 = var2.GetMutable<framework::SelectedRows>(); + auto* tensor2 = slr2->mutable_value(); + auto* rows2 = slr2->mutable_rows(); + float* tensor_data2 = nullptr; + framework::Tensor tmp_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + platform::CPUPlace cpu; + framework::TensorCopy(*tensor2, cpu, &tmp_tensor); + tensor_data2 = tmp_tensor.data<float>(); + } else { + tensor_data2 = const_cast<float*>(tensor2->data<float>()); + } + const int64_t* rows_data2 = rows2->data(); + + for (int i = 0; i < tensor_numel; ++i) { + EXPECT_EQ(tensor_data2[i], orig_tensor_data[i]); + } + EXPECT_EQ(rows_data2[0], 3); + EXPECT_EQ(rows_data2[1], 10); +} + +// TEST(SelectedRows, CPU) { +// platform::CPUPlace place; +// RunSerdeTestSelectedRows(place); +// } + +// TEST(SelectedRows, GPU) { +// platform::CUDAPlace place; +// RunSerdeTestSelectedRows(place); +// } + +TEST(Tensor, CPU) { + platform::CPUPlace place; + RunSerdeTestTensor(place); +} + +TEST(Tensor, GPU) { + platform::CUDAPlace place; + RunSerdeTestTensor(place); +} \ No newline at end of file From a7d236d608e388a5a23061b77f7a1417993f3d7e Mon Sep 17 00:00:00 2001 From: weixing02 <564445201@qq.com> Date: Fri, 9 Mar 2018 11:24:20 +0800 Subject: [PATCH 37/40] Move 2 pictures from /v2 to /fluid (#8846) --- doc/{v2 => fluid}/howto/optimization/pprof_1.png | Bin doc/{v2 => fluid}/howto/optimization/pprof_2.png | Bin 2 files changed, 0 insertions(+), 0 deletions(-) rename doc/{v2 => fluid}/howto/optimization/pprof_1.png (100%) rename 
doc/{v2 => fluid}/howto/optimization/pprof_2.png (100%) diff --git a/doc/v2/howto/optimization/pprof_1.png b/doc/fluid/howto/optimization/pprof_1.png similarity index 100% rename from doc/v2/howto/optimization/pprof_1.png rename to doc/fluid/howto/optimization/pprof_1.png diff --git a/doc/v2/howto/optimization/pprof_2.png b/doc/fluid/howto/optimization/pprof_2.png similarity index 100% rename from doc/v2/howto/optimization/pprof_2.png rename to doc/fluid/howto/optimization/pprof_2.png From 90215b784487efad690b05749c34d03cc984cbb5 Mon Sep 17 00:00:00 2001 From: kexinzhao <kexin.zhao.paddle@gmail.com> Date: Thu, 8 Mar 2018 20:05:45 -0800 Subject: [PATCH 38/40] Add float16 GEMM math function on GPU (#8695) * test cpu float16 data transform * add isnan etc * small fix * fix containsNAN test error * add data_type transform GPU test * add float16 GPU example * fix error * fix GPU test error * initial commit * fix error * small fix * add more gemm fp16 tests * fix error * add utility function --- paddle/fluid/operators/math/math_function.cc | 39 ++ paddle/fluid/operators/math/math_function.cu | 108 +++++ .../operators/math/math_function_test.cu | 396 +++++++++++++----- paddle/fluid/platform/dynload/cublas.h | 3 + 4 files changed, 449 insertions(+), 97 deletions(-) diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index f7f33917d7..35d251f71a 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -15,11 +15,23 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function_impl.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +using float16 = paddle::platform::float16; + +template <> +void gemm<platform::CPUDeviceContext, float16>( + const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const float16* B, const float16 beta, + float16* C) { + PADDLE_THROW("float16 GEMM not supported on CPU"); +} + template <> void gemm<platform::CPUDeviceContext, float>( const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, @@ -46,6 +58,15 @@ void gemm<platform::CPUDeviceContext, double>( beta, C, ldc); } +template <> +void gemm<platform::CPUDeviceContext, float16>( + const platform::CPUDeviceContext& context, const bool transA, + const bool transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const int lda, const float16* B, + const int ldb, const float16 beta, float16* C, const int ldc) { + PADDLE_THROW("float16 GEMM not supported on CPU"); +} + template <> void gemm<platform::CPUDeviceContext, float>( const platform::CPUDeviceContext& context, const bool transA, @@ -68,6 +89,15 @@ void gemm<platform::CPUDeviceContext, double>( lda, B, ldb, beta, C, ldc); } +template <> +void matmul<platform::CPUDeviceContext, float16>( + const platform::CPUDeviceContext& context, + const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, float16 alpha, + framework::Tensor* matrix_out, float16 beta) { + PADDLE_THROW("float16 matmul not supported on CPU"); +} + template <> void matmul<platform::CPUDeviceContext, float>( const platform::CPUDeviceContext& context, @@ -126,6 +156,15 @@ void 
matmul<platform::CPUDeviceContext, double>( matrix_b.data<double>(), beta, matrix_out->data<double>()); } +template <> +void batched_gemm<platform::CPUDeviceContext, float16>( + const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const float16* B, const float16 beta, + float16* C, const int batchCount, const int strideA, const int strideB) { + PADDLE_THROW("float16 batched_gemm not supported on CPU"); +} + #ifdef PADDLE_WITH_MKLML // Use cblas_{s,d}gemm_batched if available: Run with 1 group of size batchSize. template <> diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index f8d0349ac5..36655508be 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -16,11 +16,40 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function_impl.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { namespace math { +using float16 = paddle::platform::float16; + +template <> +void gemm<platform::CUDADeviceContext, float16>( + const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const float16* B, const float16 beta, + float16* C) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + const half h_alpha = static_cast<const half>(alpha); + const half h_beta = static_cast<const half>(beta); + const half* h_A = reinterpret_cast<const half*>(A); + const half* h_B = reinterpret_cast<const half*>(B); + half* h_C = reinterpret_cast<half*>(C); + + PADDLE_ENFORCE(platform::dynload::cublasHgemm( + context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, + h_A, lda, &h_beta, h_C, N)); +} + template <> void gemm<platform::CUDADeviceContext, float>( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, @@ -60,6 +89,28 @@ void gemm<platform::CUDADeviceContext, double>( lda, &beta, C, N)); } +template <> +void gemm<platform::CUDADeviceContext, float16>( + const platform::CUDADeviceContext& context, const bool transA, + const bool transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const int lda, const float16* B, + const int ldb, const float16 beta, float16* C, const int ldc) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + const half h_alpha = static_cast<const half>(alpha); + const half h_beta = static_cast<const half>(beta); + const half* h_A = reinterpret_cast<const half*>(A); + const half* h_B = reinterpret_cast<const half*>(B); + half* h_C = reinterpret_cast<half*>(C); + + PADDLE_ENFORCE(platform::dynload::cublasHgemm( + context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, + h_A, lda, &h_beta, h_C, ldc)); +} + template <> void gemm<platform::CUDADeviceContext, float>( const platform::CUDADeviceContext& context, const bool transA, @@ -90,6 +141,35 @@ void gemm<platform::CUDADeviceContext, double>( lda, &beta, C, ldc)); } +template <> +void matmul<platform::CUDADeviceContext, float16>( + const platform::CUDADeviceContext& context, + const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, float16 alpha, + framework::Tensor* matrix_out, float16 beta) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in CUDAPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm<platform::CUDADeviceContext, float16>( + context, transA, transB, M, N, K, alpha, matrix_a.data<float16>(), + matrix_b.data<float16>(), beta, matrix_out->data<float16>()); +} + template <> void matmul<platform::CUDADeviceContext, float>( const platform::CUDADeviceContext& context, @@ -148,6 +228,34 @@ void matmul<platform::CUDADeviceContext, double>( matrix_b.data<double>(), beta, matrix_out->data<double>()); } +template <> +void batched_gemm<platform::CUDADeviceContext, float16>( + const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, + const float16 alpha, const float16* A, const float16* B, const float16 beta, + float16* C, const int batchCount, const int strideA, const int strideB) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + const int strideC = M * N; + + const half h_alpha = static_cast<const half>(alpha); + const half h_beta = static_cast<const half>(beta); + const half* h_A = reinterpret_cast<const half*>(A); + const half* h_B = reinterpret_cast<const half*>(B); + half* h_C = reinterpret_cast<half*>(C); + + PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( + context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, + strideB, h_A, lda, strideA, &h_beta, h_C, ldc, strideC, batchCount)); +} + template <> void batched_gemm<platform::CUDADeviceContext, float>( const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, diff --git a/paddle/fluid/operators/math/math_function_test.cu b/paddle/fluid/operators/math/math_function_test.cu index 207d6a87bc..442e62d563 100644 --- a/paddle/fluid/operators/math/math_function_test.cu +++ b/paddle/fluid/operators/math/math_function_test.cu @@ -14,30 +14,41 @@ #include "gtest/gtest.h" #include "paddle/fluid/operators/math/math_function.h" -TEST(math_function, notrans_mul_trans) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor out_gpu; - paddle::framework::Tensor out; - - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place); +void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size, + const std::vector<float>& data) { + PADDLE_ENFORCE_EQ(size, data.size()); + for (size_t i = 0; i < data.size(); ++i) { + in_ptr[i] = paddle::platform::float16(data[i]); + } +} + +TEST(math_function, notrans_mul_trans_fp32) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor out_gpu; + Tensor out; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); + + float* input1_ptr = input1.mutable_data<float>({2, 3}, cpu_place); float arr[6] = {0, 1, 2, 3, 4, 5}; memcpy(input1_ptr, arr, 6 * sizeof(float)); - auto* gpu_place = new paddle::platform::CUDAPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu); + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input1, gpu_place, context, &input2_gpu); - out_gpu.mutable_data<float>({2, 2}, *gpu_place); + out_gpu.mutable_data<float>({2, 2}, gpu_place); - paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>( + paddle::operators::math::matmul<CUDADeviceContext, float>( context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); - paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out); + TensorCopy(out_gpu, cpu_place, context, &out); float* out_ptr = out.data<float>(); context.Wait(); @@ -45,33 +56,71 @@ TEST(math_function, notrans_mul_trans) { EXPECT_EQ(out_ptr[1], 14); EXPECT_EQ(out_ptr[2], 14); EXPECT_EQ(out_ptr[3], 50); - delete gpu_place; } -TEST(math_function, trans_mul_notrans) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor out_gpu; - paddle::framework::Tensor out; +TEST(math_function, notrans_mul_trans_fp16) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor out_gpu; + Tensor out; + + 
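A note on the transpose swapping used throughout the cuBLAS wrappers above: cuBLAS assumes column-major storage while these tensors are row-major, so C = A*B is obtained by asking cuBLAS for C^T = B^T * A^T, which is the same bytes reinterpreted. That is why every call passes cuTransB and B before cuTransA and A, and uses N as the leading dimension of C. The sketch below illustrates the layout identity with plain CPU loops standing in for cublasHgemm; it is an illustration, not Paddle code:

```cpp
#include <cassert>

// Row-major M x N matrix C stored flat: element (i, j) is C[i * N + j].
// Read column-major with leading dimension N, the same bytes are C^T,
// so a column-major GEMM producing C^T = B^T * A^T fills exactly the
// row-major bytes of C = A * B -- no explicit transpose needed.
int main() {
  const int M = 2, K = 3, N = 2;
  float A[M * K] = {1, 2, 3, 4, 5, 6};     // row-major M x K
  float B[K * N] = {7, 8, 9, 10, 11, 12};  // row-major K x N
  float C[M * N] = {0};

  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < K; ++k) C[i * N + j] += A[i * K + k] * B[k * N + j];

  assert(C[0] == 58 && C[1] == 64 && C[2] == 139 && C[3] == 154);
  return 0;
}
```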
CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); + + float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place); + fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); + + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input1, gpu_place, context, &input2_gpu); + + out_gpu.mutable_data<float16>({2, 2}, gpu_place); + + paddle::operators::math::matmul<CUDADeviceContext, float16>( + context, input1_gpu, false, input2_gpu, true, float16(1), &out_gpu, + float16(0)); + + TensorCopy(out_gpu, cpu_place, context, &out); + + float16* out_ptr = out.data<float16>(); + context.Wait(); + EXPECT_EQ(static_cast<float>(out_ptr[0]), 5); + EXPECT_EQ(static_cast<float>(out_ptr[1]), 14); + EXPECT_EQ(static_cast<float>(out_ptr[2]), 14); + EXPECT_EQ(static_cast<float>(out_ptr[3]), 50); +} + +TEST(math_function, trans_mul_notrans_fp32) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor out_gpu; + Tensor out; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place); + float* input1_ptr = input1.mutable_data<float>({2, 3}, cpu_place); float arr[6] = {0, 1, 2, 3, 4, 5}; memcpy(input1_ptr, arr, 6 * sizeof(float)); - auto* gpu_place = new paddle::platform::CUDAPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input1, gpu_place, context, &input2_gpu); - paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::TensorCopy(input1, *gpu_place, context, &input2_gpu); - - out_gpu.mutable_data<float>({3, 3}, *gpu_place); + out_gpu.mutable_data<float>({3, 3}, gpu_place); paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float>( context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); - paddle::framework::TensorCopy(out_gpu, *cpu_place, context, &out); + TensorCopy(out_gpu, cpu_place, context, &out); float* out_ptr = out.data<float>(); context.Wait(); @@ -84,45 +133,88 @@ TEST(math_function, trans_mul_notrans) { EXPECT_EQ(out_ptr[6], 15); EXPECT_EQ(out_ptr[7], 22); EXPECT_EQ(out_ptr[8], 29); - delete gpu_place; } -TEST(math_function, gemm_notrans_cublas) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input2; - paddle::framework::Tensor input3; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor input3_gpu; +TEST(math_function, trans_mul_notrans_fp16) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor out_gpu; + Tensor out; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); + + float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place); + fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); + + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input1, gpu_place, context, &input2_gpu); + + out_gpu.mutable_data<float16>({3, 3}, gpu_place); + + paddle::operators::math::matmul<paddle::platform::CUDADeviceContext, float16>( + context, input1_gpu, true, input2_gpu, false, float16(1), &out_gpu, + float16(0)); + + TensorCopy(out_gpu, cpu_place, context, &out); + + float16* out_ptr = out.data<float16>(); + context.Wait(); + 
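The assertions that follow all widen `float16` to `float` before comparing. `platform::float16` is primarily a storage type, and widening keeps gtest's comparison and failure printing well defined. A hypothetical helper (not part of this patch) capturing the pattern:

```cpp
#include "gtest/gtest.h"
#include "paddle/fluid/platform/float16.h"

// Compare a float16 buffer against expected float values by widening
// each element first, as the fp16 tests in this file do inline.
void ExpectFp16Eq(const paddle::platform::float16* got, const float* want,
                  int n) {
  for (int i = 0; i < n; ++i) {
    EXPECT_EQ(static_cast<float>(got[i]), want[i]);
  }
}
```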
EXPECT_EQ(static_cast<float>(out_ptr[0]), 9); + EXPECT_EQ(static_cast<float>(out_ptr[1]), 12); + EXPECT_EQ(static_cast<float>(out_ptr[2]), 15); + EXPECT_EQ(static_cast<float>(out_ptr[3]), 12); + EXPECT_EQ(static_cast<float>(out_ptr[4]), 17); + EXPECT_EQ(static_cast<float>(out_ptr[5]), 22); + EXPECT_EQ(static_cast<float>(out_ptr[6]), 15); + EXPECT_EQ(static_cast<float>(out_ptr[7]), 22); + EXPECT_EQ(static_cast<float>(out_ptr[8]), 29); +} + +TEST(math_function, gemm_notrans_cublas_fp32) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input2; + Tensor input3; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor input3_gpu; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); int m = 2; int n = 3; int k = 3; - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place); + float* input1_ptr = input1.mutable_data<float>({2, 3}, cpu_place); float arr1[6] = {0, 1, 2, 3, 4, 5}; memcpy(input1_ptr, arr1, 6 * sizeof(float)); - float* input2_ptr = input2.mutable_data<float>({3, 4}, *cpu_place); + float* input2_ptr = input2.mutable_data<float>({3, 4}, cpu_place); float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; memcpy(input2_ptr, arr2, 12 * sizeof(float)); - float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place); + float* input3_ptr = input3.mutable_data<float>({2, 4}, cpu_place); float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; memcpy(input3_ptr, arr3, 8 * sizeof(float)); - auto* gpu_place = new paddle::platform::CUDAPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu); - paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu); + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input2, gpu_place, context, &input2_gpu); + TensorCopy(input3, gpu_place, context, &input3_gpu); float* a = input1_gpu.data<float>(); float* b = input2_gpu.data<float>(); - float* c = input3_gpu.mutable_data<float>(*gpu_place); + float* c = input3_gpu.mutable_data<float>(gpu_place); paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); - paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3); + TensorCopy(input3_gpu, cpu_place, context, &input3); // numpy code: // a = np.arange(6).reshape(2, 3) @@ -139,47 +231,105 @@ TEST(math_function, gemm_notrans_cublas) { EXPECT_EQ(input3_ptr[5], 73); EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); - delete gpu_place; } -TEST(math_function, gemm_trans_cublas) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input2; - paddle::framework::Tensor input3; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor input3_gpu; +TEST(math_function, gemm_notrans_cublas_fp16) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input2; + Tensor input3; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor input3_gpu; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); + + int m = 2; + int n = 3; + int k = 3; + float16* input1_ptr = input1.mutable_data<float16>({2, 3}, cpu_place); + fill_fp16_data(input1_ptr, input1.numel(), {0, 1, 2, 3, 4, 5}); + float16* input2_ptr = 
input2.mutable_data<float16>({3, 4}, cpu_place); + fill_fp16_data(input2_ptr, input2.numel(), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); + float16* input3_ptr = input3.mutable_data<float16>({2, 4}, cpu_place); + fill_fp16_data(input3_ptr, input3.numel(), {0, 1, 2, 3, 4, 5, 6, 7}); + + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input2, gpu_place, context, &input2_gpu); + TensorCopy(input3, gpu_place, context, &input3_gpu); + float16* a = input1_gpu.data<float16>(); + float16* b = input2_gpu.data<float16>(); + float16* c = input3_gpu.mutable_data<float16>(gpu_place); + + paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>( + context, false, false, m, n, k, float16(1), a, 3, b + 1, 4, float16(1), + c + 1, 4); + + TensorCopy(input3_gpu, cpu_place, context, &input3); + + // numpy code: + // a = np.arange(6).reshape(2, 3) + // b = np.arange(12).reshape(3, 4)[:, 1:] + // c = np.arange(8).reshape(2, 4)[:, 1:] + // out = np.arange(8).reshape(2, 4) + // out[:, 1:] = np.dot(a, b) + c + context.Wait(); + EXPECT_EQ(static_cast<float>(input3_ptr[0]), 0); + EXPECT_EQ(static_cast<float>(input3_ptr[1]), 24); + EXPECT_EQ(static_cast<float>(input3_ptr[2]), 28); + EXPECT_EQ(static_cast<float>(input3_ptr[3]), 32); + EXPECT_EQ(static_cast<float>(input3_ptr[4]), 4); + EXPECT_EQ(static_cast<float>(input3_ptr[5]), 73); + EXPECT_EQ(static_cast<float>(input3_ptr[6]), 86); + EXPECT_EQ(static_cast<float>(input3_ptr[7]), 99); +} + +TEST(math_function, gemm_trans_cublas_fp32) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor input1; + Tensor input2; + Tensor input3; + Tensor input1_gpu; + Tensor input2_gpu; + Tensor input3_gpu; + + CPUPlace cpu_place; + CUDAPlace gpu_place(0); + CUDADeviceContext context(gpu_place); int m = 2; int n = 3; int k = 3; - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data<float>({2, 3}, *cpu_place); + float* input1_ptr = input1.mutable_data<float>({2, 3}, cpu_place); float arr1[6] = {0, 1, 2, 3, 4, 5}; memcpy(input1_ptr, arr1, 6 * sizeof(float)); - float* input2_ptr = input2.mutable_data<float>({4, 3}, *cpu_place); + float* input2_ptr = input2.mutable_data<float>({4, 3}, cpu_place); float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; memcpy(input2_ptr, arr2, 12 * sizeof(float)); - float* input3_ptr = input3.mutable_data<float>({2, 4}, *cpu_place); + float* input3_ptr = input3.mutable_data<float>({2, 4}, cpu_place); float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; memcpy(input3_ptr, arr3, 8 * sizeof(float)); - auto* gpu_place = new paddle::platform::CUDAPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - paddle::framework::TensorCopy(input1, *gpu_place, context, &input1_gpu); - paddle::framework::TensorCopy(input2, *gpu_place, context, &input2_gpu); - paddle::framework::TensorCopy(input3, *gpu_place, context, &input3_gpu); + TensorCopy(input1, gpu_place, context, &input1_gpu); + TensorCopy(input2, gpu_place, context, &input2_gpu); + TensorCopy(input3, gpu_place, context, &input3_gpu); float* a = input1_gpu.data<float>(); float* b = input2_gpu.data<float>(); - float* c = input3_gpu.mutable_data<float>(*gpu_place); + float* c = input3_gpu.mutable_data<float>(gpu_place); paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>( context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); - paddle::framework::TensorCopy(input3_gpu, *cpu_place, context, &input3); - context.Wait(); + TensorCopy(input3_gpu, cpu_place, context, &input3); 
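The pointer arithmetic in these gemm calls (`b + 1` with `ldb = 4` in the notrans tests, `b + 3` with `ldb = 3` here to select rows 1:4 of the 4 x 3 buffer, and `c + 1` with `ldc = 4`) addresses sub-matrices in place: the leading dimension stays the row stride of the parent buffer. A standalone illustration of the addressing (not Paddle code):

```cpp
#include <cassert>

// Row-major 3 x 4 parent buffer; (b + 1, ldb = 4) names the 3 x 3
// block b[:, 1:4] without copying: element (i, j) of the block is
// block[i * ldb + j], using the parent's row stride as ldb.
int main() {
  float b[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};  // 3 x 4
  const int ldb = 4;
  const float* block = b + 1;
  assert(block[0 * ldb + 0] == 1);   // block row 0, col 0
  assert(block[2 * ldb + 2] == 11);  // block row 2, col 2
  return 0;
}
```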
diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h
index 580ed9bb57..fa9041134d 100644
--- a/paddle/fluid/platform/dynload/cublas.h
+++ b/paddle/fluid/platform/dynload/cublas.h
@@ -68,6 +68,8 @@ extern void *cublas_dso_handle;
   __macro(cublasDgemv_v2);                \
   __macro(cublasSgemm_v2);                \
   __macro(cublasDgemm_v2);                \
+  __macro(cublasHgemm);                   \
+  __macro(cublasSgemmEx);                 \
   __macro(cublasSgeam_v2);                \
   __macro(cublasDgeam_v2);                \
   __macro(cublasCreate_v2);               \
@@ -83,6 +85,7 @@ extern void *cublas_dso_handle;
   __macro(cublasDgemmStridedBatched);     \
   __macro(cublasCgemmStridedBatched);     \
   __macro(cublasZgemmStridedBatched);     \
+  __macro(cublasHgemmStridedBatched);     \
   __macro(cublasSgetrfBatched);           \
   __macro(cublasSgetriBatched);           \
   __macro(cublasDgetrfBatched);           \
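[Editor's note, not part of the patch] The newly dynloaded symbols back the float16 gemm path the tests above exercise: cublasHgemm does the arithmetic in half precision (it needs a GPU of compute capability 5.3 or later), while cublasSgemmEx keeps fp16 storage but accumulates in fp32 on older parts. A sketch of the call shape — the wrapper name half_gemm is mine:

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // C = alpha * A * B + beta * C, all operands fp16 and column-major.
    // cuBLAS expects column-major data, which is why row-major wrappers
    // such as Paddle's gemm typically swap operands before forwarding.
    cublasStatus_t half_gemm(cublasHandle_t handle, int m, int n, int k,
                             const __half* alpha, const __half* A,
                             const __half* B, const __half* beta, __half* C) {
      return cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, alpha,
                         A, m, B, k, beta, C, m);
    }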
From 4e517881f7e4d0ca8e3dac7234485fd6870418cc Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Fri, 9 Mar 2018 14:45:48 +0800
Subject: [PATCH 39/40] remove HasNext

---
 doc/design/cpp_data_feeding.md                         |  3 +--
 paddle/fluid/framework/reader.h                        |  4 ----
 paddle/fluid/operators/read_op.cc                      | 11 ++++++-----
 .../fluid/operators/reader/create_batch_reader_op.cc   |  8 ++++----
 .../reader/create_random_data_generator_op.cc          |  2 --
 .../operators/reader/create_shuffle_reader_op.cc       |  8 ++++----
 6 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/doc/design/cpp_data_feeding.md b/doc/design/cpp_data_feeding.md
index 40205350f9..a122af8cb9 100644
--- a/doc/design/cpp_data_feeding.md
+++ b/doc/design/cpp_data_feeding.md
@@ -20,9 +20,8 @@ class ReaderBase {
     PADDLE_ENFORCE(!shapes_.empty());
   }
   // Read the next batch of data. (A 'batch' can be only one instance)
+  // If the next batch doesn't exist, the 'out' will be an empty std::vector.
   virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
-  // Show whether the next bacth exists.
-  virtual bool HasNext() const = 0;
   // Reinitialize the reader and read the file from the begin.
   virtual void ReInit() = 0;

diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h
index 27ab6e750c..1be3f4ef1f 100644
--- a/paddle/fluid/framework/reader.h
+++ b/paddle/fluid/framework/reader.h
@@ -26,7 +26,6 @@ class ReaderBase {
     PADDLE_ENFORCE(!shapes_.empty());
   }
   virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
-  virtual bool HasNext() const = 0;
 
   virtual void ReInit() = 0;
 
@@ -52,8 +51,6 @@ class DecoratedReader : public ReaderBase {
     PADDLE_ENFORCE_NOT_NULL(reader_);
   }
 
-  bool HasNext() const override { return reader_->HasNext(); }
-
   void ReInit() override { reader_->ReInit(); }
 
  protected:
@@ -69,7 +66,6 @@ class ReaderHolder {
   ReaderBase* Get() const { return reader_.get(); }
 
   void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
-  bool HasNext() const { return reader_->HasNext(); }
   void ReInit() { reader_->ReInit(); }
 
   DDim shape(size_t idx) const { return reader_->shape(idx); }
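[Editor's note, not part of the patch] With HasNext() removed, end-of-data is now signalled in-band: ReadNext() leaves the output vector empty once the stream is exhausted. This collapses the old check-then-read pair into one call, so decorated readers no longer have to forward HasNext() to the reader they wrap. A consumer loop under the new contract looks roughly like this sketch of mine (Process is a placeholder, not a Paddle function):

    std::vector<paddle::framework::LoDTensor> batch;
    while (true) {
      reader->ReadNext(&batch);  // 'reader' is a ReaderHolder*, as in ReadOp
      if (batch.empty()) break;  // in-band end-of-data signal
      Process(batch);            // placeholder for real consumer logic
    }
    // reader->ReInit() rewinds to the first batch, as ReadOp does below.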
diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc
index 62beab82d4..2a5605e0d3 100644
--- a/paddle/fluid/operators/read_op.cc
+++ b/paddle/fluid/operators/read_op.cc
@@ -60,15 +60,16 @@ class ReadOp : public framework::OperatorBase {
                const platform::Place& dev_place) const override {
     framework::ReaderHolder* reader =
         scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
-    if (!reader->HasNext()) {
+    std::vector<std::string> out_arg_names = Outputs("Out");
+    std::vector<framework::LoDTensor> ins;
+    reader->ReadNext(&ins);
+    if (ins.empty()) {
       reader->ReInit();
+      reader->ReadNext(&ins);
       PADDLE_ENFORCE(
-          reader->HasNext(),
+          !ins.empty(),
           "Reader can not read the next data even it has been re-initialized.");
     }
-    std::vector<std::string> out_arg_names = Outputs("Out");
-    std::vector<framework::LoDTensor> ins;
-    reader->ReadNext(&ins);
     PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
     for (size_t i = 0; i < ins.size(); ++i) {
       auto* out =
diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc
index bac043a552..9559159e82 100644
--- a/paddle/fluid/operators/reader/create_batch_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc
@@ -68,10 +68,10 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   buffer_.clear();
   buffer_.reserve(batch_size_);
   for (int i = 0; i < batch_size_; ++i) {
-    if (reader_->HasNext()) {
-      buffer_.push_back(std::vector<framework::LoDTensor>());
-      reader_->ReadNext(&buffer_.back());
-    } else {
+    buffer_.push_back(std::vector<framework::LoDTensor>());
+    reader_->ReadNext(&buffer_.back());
+    if (buffer.back().empty()) {
+      buffer_.pop_back();
       break;
     }
   }
diff --git a/paddle/fluid/operators/reader/create_random_data_generator_op.cc b/paddle/fluid/operators/reader/create_random_data_generator_op.cc
index f77ab8ab19..73c39b5da4 100644
--- a/paddle/fluid/operators/reader/create_random_data_generator_op.cc
+++ b/paddle/fluid/operators/reader/create_random_data_generator_op.cc
@@ -50,8 +50,6 @@ class RandomDataGenerator : public framework::FileReader {
     }
   }
 
-  bool HasNext() const override { return true; }
-
   void ReInit() override { return; }
 
  private:
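[Editor's note, not part of the patch] BatchReader above and ShuffleReader below now detect exhaustion with the same speculative-slot pattern: push an empty vector, read into it, and pop it again if nothing arrived. Factored out, the shape is roughly this sketch (the helper name TryReadNext is mine, not in the patch):

    bool TryReadNext(
        paddle::framework::ReaderBase* reader,
        std::vector<std::vector<paddle::framework::LoDTensor>>* buf) {
      buf->push_back(std::vector<paddle::framework::LoDTensor>());
      reader->ReadNext(&buf->back());
      if (buf->back().empty()) {
        buf->pop_back();  // nothing was read; discard the speculative slot
        return false;     // the underlying reader is exhausted
      }
      return true;
    }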
diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
index 3e8b463efc..4dac383110 100644
--- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc
@@ -39,10 +39,10 @@ void ShuffleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   buffer_.clear();
   buffer_.reserve(buffer_size_);
   for (int i = 0; i < buffer_size_; ++i) {
-    if (reader_->HasNext()) {
-      buffer_.push_back(std::vector<framework::LoDTensor>());
-      reader_->ReadNext(&buffer_.back());
-    } else {
+    buffer_.push_back(std::vector<framework::LoDTensor>());
+    reader_->ReadNext(&buffer_.back());
+    if (buffer_.back().empty()) {
+      buffer_.pop_back();
       break;
     }
   }

From 6e5736e2700decf5b991e0f84216fcea13983834 Mon Sep 17 00:00:00 2001
From: fengjiayi <fengjiayi@baidu.com>
Date: Fri, 9 Mar 2018 15:07:20 +0800
Subject: [PATCH 40/40] fix a compile error

---
 doc/design/cpp_data_feeding.md                          | 2 +-
 paddle/fluid/operators/reader/create_batch_reader_op.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/doc/design/cpp_data_feeding.md b/doc/design/cpp_data_feeding.md
index a122af8cb9..22c2a925eb 100644
--- a/doc/design/cpp_data_feeding.md
+++ b/doc/design/cpp_data_feeding.md
@@ -20,7 +20,7 @@ class ReaderBase {
     PADDLE_ENFORCE(!shapes_.empty());
   }
   // Read the next batch of data. (A 'batch' can be only one instance)
-  // If the next batch doesn't exist, the 'out' will be an empty std::vector.
+  // If the next batch doesn't exist, the '*out' will be an empty std::vector.
   virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
   // Reinitialize the reader and read the file from the begin.
diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc
index 9559159e82..277f2856c0 100644
--- a/paddle/fluid/operators/reader/create_batch_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc
@@ -70,7 +70,7 @@ void BatchReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   for (int i = 0; i < batch_size_; ++i) {
     buffer_.push_back(std::vector<framework::LoDTensor>());
     reader_->ReadNext(&buffer_.back());
-    if (buffer.back().empty()) {
+    if (buffer_.back().empty()) {
      buffer_.pop_back();
       break;
     }
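[Editor's note, not part of the patch] Patch 40/40 is the follow-up to patch 39: it tightens the new doc comment ('out' becomes '*out', since it is the pointed-to vector that ends up empty, not the pointer) and repairs the compile error in BatchReader::ReadNext, where patch 39 wrote the undeclared identifier buffer instead of the member buffer_.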