From 85a41df32d2793da5c1c49b9c36a3781567f4a7e Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Wed, 16 May 2018 15:14:16 +0800 Subject: [PATCH 01/12] Init commit --- paddle/fluid/operators/random_crop_op.cc | 59 ++++++++ paddle/fluid/operators/random_crop_op.h | 167 +++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 paddle/fluid/operators/random_crop_op.cc create mode 100644 paddle/fluid/operators/random_crop_op.h diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc new file mode 100644 index 0000000000..cb4bdde0ee --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/operators/random_crop_op.h" +#include + +namespace paddle { +namespace operators { +class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", ""); + AddOutput("Y", ""); + AddInput("Seed", ""); + AddOutput("SeedOut", "").AsDispensable(); + AddAttr>("shape", ""); + } +}; + +class RandomCropOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* context) const override { + auto shape = context->Attrs().Get>("shape"); + auto x_dim = context->GetInputDim("X"); + PADDLE_ENFORCE_EQ(x_dim.size(), static_cast(shape.size())); + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == -1) { + shape[i] = static_cast(x_dim[i]); + } else { + PADDLE_ENFORCE_GE(x_dim[i], shape[i]); + } + } + context->SetOutputDim("Y", framework::make_ddim(shape)); + context->SetOutputDim("SeedOut", framework::make_ddim({1})); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace f = paddle::framework; +REGISTER_OPERATOR(random_crop, f::OperatorWithKernel, ops::RandomCropOpMaker, + ops::RandomCropOpInferShape); +template +using Kernel = ops::RandomCropKernel; + +REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h new file mode 100644 index 0000000000..86a22227f3 --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.h @@ -0,0 +1,167 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/for_range.h" +#include "thrust/random.h" + +namespace paddle { +namespace operators { + +template +struct Random; + +template <> +struct Random { + using Engine = std::minstd_rand; + + template + using UniformIntDist = std::uniform_int_distribution; +}; + +template <> +struct Random { + using Engine = thrust::minstd_rand; + + template + using UniformIntDist = thrust::uniform_int_distribution; +}; + +template +HOSTDEVICE inline void RandomCropImpl(const T* x, size_t* x_dim, T* out, + size_t* out_dim, int i, int rank, + int64_t prod_x_remain, + int64_t prod_out_remain, size_t* offset) { + size_t x_length = x_dim[rank]; + size_t out_length = out_dim[rank]; + + int64_t x_stride = prod_x_remain / x_length; + int64_t out_stride = prod_out_remain / out_length; + size_t offset_i = offset[i]; + if (x_stride == 1 && out_stride == 1) { + // In the final stage, copy from offset. + x += offset_i; + for (size_t i = 0; i < out_length; ++i) { + *out++ = *x++; + } + } else { + x += offset_i * x_stride; + for (size_t i = 0; i < out_length; ++i) { + RandomCropImpl(x, x_dim, out, out_dim, i + 1, rank, x_stride, + out_stride, offset); + x += x_stride; + out += out_stride; + } + } +} + +template +struct RandomCropFunctor { + const T* x_; + T* out_; + size_t x_dim_[9]; + size_t out_dim_[9]; + size_t prod_same_dim_; + + size_t prod_x_dim_; + size_t prod_out_dim_; + + int num_same_dim_; + int rank_; + + int64_t seed_; + + RandomCropFunctor(const T* x, T* out, int64_t seed) + : x_(x), + out_(out), + prod_same_dim_(1), + prod_x_dim_(1), + prod_out_dim_(1), + seed_(seed) { + std::fill(x_dim_, x_dim_ + sizeof(x_dim_) / sizeof(size_t), 0); + std::fill(out_dim_, out_dim_ + sizeof(out_dim_) / sizeof(size_t), 0); + } + + HOSTDEVICE void operator()(size_t i) { + typename Random::Engine engine(seed_); + engine.discard(i * (rank_ - num_same_dim_)); + + int64_t prod_x_unsame = (prod_x_dim_ / prod_same_dim_); + int64_t prod_out_unsame = (prod_out_dim_ / prod_same_dim_); + + const T* x = x_ + i * prod_x_unsame; + T* out = out_ + i * prod_out_unsame; + + size_t offset[9]; + for (int i = num_same_dim_; i < rank_; ++i) { + typename Random::template UniformIntDist dist( + 0, x_dim_[i] - out_dim_[i]); + offset[i] = dist(engine); + } + RandomCropImpl(x, x_dim_, out, out_dim_, num_same_dim_, rank_, + prod_x_unsame, prod_out_unsame, offset); + } +}; + +template +class RandomCropKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& context) const { + int64_t seed = + *context.Input("Seed")->data(); + auto& x = detail::Ref(context.Input("X")); + auto& out = detail::Ref(context.Output("Out")); + + RandomCropFunctor functor{ + x.data(), out.mutable_data(context.GetPlace()), seed}; + + auto& out_dim = out.dims(); + auto& x_dim = x.dims(); + + auto rank = x_dim.size(); + while (rank-- > 0) { + functor.x_dim_[rank] = x_dim[rank]; + functor.out_dim_[rank] = out_dim[rank]; + functor.prod_x_dim_ *= x_dim[rank]; + functor.prod_out_dim_ *= out_dim[rank]; + if (x_dim[rank] != out_dim[rank]) { + PADDLE_ENFORCE_EQ(functor.prod_same_dim_, 1); + functor.num_same_dim_ = rank; + } else { + functor.prod_same_dim_ *= out_dim[rank]; + } + } + functor.rank_ = x_dim.size(); + + platform::ForRange for_range( + context.template device_context(), + functor.prod_same_dim_); + + for_range(functor); + + Random::Engine engine(seed); + engine.discard(functor.prod_same_dim_ * + (functor.rank_ - functor.num_same_dim_)); + + *context.Output("SeedOut")->mutable_data( + platform::CPUPlace()) = engine(); + } +}; + +} // namespace operators +} // namespace paddle From 3e7ce5836f74339d78dacf59c6030c775174e04e Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 28 May 2018 13:42:55 +0800 Subject: [PATCH 02/12] stash --- paddle/fluid/operators/random_crop_op.cc | 48 +++-- paddle/fluid/operators/random_crop_op.h | 168 +++++++++--------- python/paddle/fluid/layers/nn.py | 126 ++++++------- .../tests/unittests/test_random_crop_op.py | 34 ++++ 4 files changed, 210 insertions(+), 166 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_random_crop_op.py diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index cb4bdde0ee..b9367f1d22 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -12,36 +12,52 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/random_crop_op.h" -#include namespace paddle { namespace operators { + +class RandomCropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", ""); - AddOutput("Y", ""); + AddOutput("Out", ""); AddInput("Seed", ""); AddOutput("SeedOut", "").AsDispensable(); AddAttr>("shape", ""); + AddComment(""); } }; class RandomCropOpInferShape : public framework::InferShapeBase { public: - void operator()(framework::InferShapeContext* context) const override { - auto shape = context->Attrs().Get>("shape"); - auto x_dim = context->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dim.size(), static_cast(shape.size())); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == -1) { - shape[i] = static_cast(x_dim[i]); - } else { - PADDLE_ENFORCE_GE(x_dim[i], shape[i]); - } + void operator()(framework::InferShapeContext* ctx) const override { + auto seed_dim = ctx->GetInputDim("Seed"); + PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1); + auto shape = ctx->Attrs().Get>("shape"); + auto x_dim = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GT(x_dim.size(), static_cast(shape.size())); + auto out_dim = framework::vectorize2int(x_dim); + for (size_t i = 1; i <= shape.size(); ++i) { + size_t x_i = x_dim.size() - i; + size_t shape_i = shape.size() - i; + PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]); + out_dim[x_i] = shape[shape_i]; } - context->SetOutputDim("Y", framework::make_ddim(shape)); - context->SetOutputDim("SeedOut", framework::make_ddim({1})); + ctx->SetOutputDim("Out", framework::make_ddim(out_dim)); + ctx->SetOutputDim("SeedOut", framework::make_ddim({1})); } }; @@ -50,8 +66,8 @@ class RandomCropOpInferShape : public framework::InferShapeBase { namespace ops = paddle::operators; namespace f = paddle::framework; -REGISTER_OPERATOR(random_crop, f::OperatorWithKernel, ops::RandomCropOpMaker, - ops::RandomCropOpInferShape); +REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, + ops::RandomCropOpInferShape, f::EmptyGradOpMaker); template using Kernel = ops::RandomCropKernel; diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index 86a22227f3..8764bd0bc7 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -14,11 +14,14 @@ #pragma once +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" -#include "thrust/random.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif namespace paddle { namespace operators { @@ -34,6 +37,7 @@ struct Random { using UniformIntDist = std::uniform_int_distribution; }; +#ifdef PADDLE_WITH_CUDA template <> struct Random { using Engine = thrust::minstd_rand; @@ -41,29 +45,31 @@ struct Random { template using UniformIntDist = thrust::uniform_int_distribution; }; +#endif template -HOSTDEVICE inline void RandomCropImpl(const T* x, size_t* x_dim, T* out, - size_t* out_dim, int i, int rank, - int64_t prod_x_remain, - int64_t prod_out_remain, size_t* offset) { - size_t x_length = x_dim[rank]; - size_t out_length = out_dim[rank]; - - int64_t x_stride = prod_x_remain / x_length; - int64_t out_stride = prod_out_remain / out_length; - size_t offset_i = offset[i]; - if (x_stride == 1 && out_stride == 1) { - // In the final stage, copy from offset. +HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, + const size_t* out_dims, int i, int rank, + size_t prod_x_remain, + size_t prod_out_remain, + const size_t* offsets) { + size_t x_dim_i = x_dims[i]; + size_t out_dim_i = out_dims[i]; + size_t x_stride = prod_x_remain / x_dim_i; + size_t out_stride = prod_out_remain / out_dim_i; + size_t offset_i = offsets[i]; + + if (i == rank - 1) { + PADDLE_ENFORCE(x_stride == 1 && out_stride == 1); x += offset_i; - for (size_t i = 0; i < out_length; ++i) { + for (size_t j = 0; j < out_dim_i; ++j) { *out++ = *x++; } } else { x += offset_i * x_stride; - for (size_t i = 0; i < out_length; ++i) { - RandomCropImpl(x, x_dim, out, out_dim, i + 1, rank, x_stride, - out_stride, offset); + for (size_t j = 0; j < x_dim_i; ++j) { + StridedMemcpy(x, x_dims, out, out_dims, i + 1, rank, x_stride, + out_stride, offsets); x += x_stride; out += out_stride; } @@ -74,94 +80,96 @@ template struct RandomCropFunctor { const T* x_; T* out_; - size_t x_dim_[9]; - size_t out_dim_[9]; - size_t prod_same_dim_; - - size_t prod_x_dim_; - size_t prod_out_dim_; - - int num_same_dim_; + size_t x_dims_[9]; + size_t out_dims_[9]; + int num_batchsize_dims_; int rank_; - int64_t seed_; - RandomCropFunctor(const T* x, T* out, int64_t seed) + size_t prod_x_dims_; + size_t prod_out_dims_; + size_t prod_batchsize_dims_; + size_t prod_x_ins_dims_; + size_t prod_out_ins_dims_; + + RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims, + const framework::DDim& out_dims, int num_batchsize_dims, + int64_t seed) : x_(x), out_(out), - prod_same_dim_(1), - prod_x_dim_(1), - prod_out_dim_(1), + num_batchsize_dims_(num_batchsize_dims), + rank_(x_dims.size()), seed_(seed) { - std::fill(x_dim_, x_dim_ + sizeof(x_dim_) / sizeof(size_t), 0); - std::fill(out_dim_, out_dim_ + sizeof(out_dim_) / sizeof(size_t), 0); + PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size()); + PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_); + prod_batchsize_dims_ = 1; + prod_x_ins_dims_ = 1; + prod_out_ins_dims_ = 1; + for (size_t i = 0; i < rank_; ++i) { + size_t x_dim_i = x_dims[i]; + size_t out_dim_i = out_dims[i]; + x_dims_[i] = x_dim_i; + out_dims_[i] = out_dim_i; + if (i < num_batchsize_dims_) { + PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i); + prod_batchsize_dims_ *= x_dim_i; + } else { + prod_x_ins_dims_ *= x_dim_i; + prod_out_ins_dims_ *= out_dim_i; + } + } + prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_; + prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_; } - HOSTDEVICE void operator()(size_t i) { + HOSTDEVICE void operator()(size_t ins_idx) { typename Random::Engine engine(seed_); - engine.discard(i * (rank_ - num_same_dim_)); - - int64_t prod_x_unsame = (prod_x_dim_ / prod_same_dim_); - int64_t prod_out_unsame = (prod_out_dim_ / prod_same_dim_); - - const T* x = x_ + i * prod_x_unsame; - T* out = out_ + i * prod_out_unsame; - - size_t offset[9]; - for (int i = num_same_dim_; i < rank_; ++i) { + engine.discard(ins_idx * (rank_ - num_batchsize_dims_)); + size_t offsets[9]; + for (int i = num_batchsize_dims_; i < rank_; ++i) { typename Random::template UniformIntDist dist( - 0, x_dim_[i] - out_dim_[i]); - offset[i] = dist(engine); + 0, x_dims_[i] - out_dims_[i]); + offsets[i] = dist(engine); } - RandomCropImpl(x, x_dim_, out, out_dim_, num_same_dim_, rank_, - prod_x_unsame, prod_out_unsame, offset); + + const T* x = x_ + ins_idx * prod_x_ins_dims_; + T* out = out_ + ins_idx * prod_out_ins_dims_; + + StridedMemcpy(x, x_dims_ + num_batchsize_dims_, out, + out_dims_ + num_batchsize_dims_, 0, + rank_ - num_batchsize_dims_, prod_x_ins_dims_, + prod_out_ins_dims_, offsets); } }; template class RandomCropKernel : public framework::OpKernel { public: - virtual void Compute(const framework::ExecutionContext& context) const { - int64_t seed = - *context.Input("Seed")->data(); - auto& x = detail::Ref(context.Input("X")); - auto& out = detail::Ref(context.Output("Out")); - - RandomCropFunctor functor{ - x.data(), out.mutable_data(context.GetPlace()), seed}; - - auto& out_dim = out.dims(); - auto& x_dim = x.dims(); - - auto rank = x_dim.size(); - while (rank-- > 0) { - functor.x_dim_[rank] = x_dim[rank]; - functor.out_dim_[rank] = out_dim[rank]; - functor.prod_x_dim_ *= x_dim[rank]; - functor.prod_out_dim_ *= out_dim[rank]; - if (x_dim[rank] != out_dim[rank]) { - PADDLE_ENFORCE_EQ(functor.prod_same_dim_, 1); - functor.num_same_dim_ = rank; - } else { - functor.prod_same_dim_ *= out_dim[rank]; - } - } - functor.rank_ = x_dim.size(); - + virtual void Compute(const framework::ExecutionContext& ctx) const { + int64_t seed = *ctx.Input("Seed")->data(); + auto shape = ctx.Attr>("shape"); + auto& x = detail::Ref(ctx.Input("X")); + auto& out = detail::Ref(ctx.Output("Out")); + + int num_batchsize_dims = x.dims().size() - shape.size(); + RandomCropFunctor functor( + x.data(), out.mutable_data(ctx.GetPlace()), x.dims(), out.dims(), + num_batchsize_dims, seed); platform::ForRange for_range( - context.template device_context(), - functor.prod_same_dim_); + ctx.template device_context(), + functor.prod_batchsize_dims_); for_range(functor); Random::Engine engine(seed); - engine.discard(functor.prod_same_dim_ * - (functor.rank_ - functor.num_same_dim_)); - - *context.Output("SeedOut")->mutable_data( + engine.discard(functor.prod_batchsize_dims_ * + (functor.rank_ - functor.num_batchsize_dims_)); + *ctx.Output("SeedOut")->mutable_data( platform::CPUPlace()) = engine(); } }; +// TODO(fengjiayi): Backward of random crop op + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 04ee8ac9ae..42e26dd366 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -24,64 +24,19 @@ from tensor import concat import utils __all__ = [ - 'fc', - 'embedding', - 'dynamic_lstm', - 'dynamic_lstmp', - 'dynamic_gru', - 'gru_unit', - 'linear_chain_crf', - 'crf_decoding', - 'cos_sim', - 'cross_entropy', - 'square_error_cost', - 'chunk_eval', - 'sequence_conv', - 'conv2d', - 'sequence_pool', - 'sequence_softmax', - 'softmax', - 'pool2d', - 'batch_norm', - 'beam_search_decode', - 'conv2d_transpose', - 'sequence_expand', - 'lstm_unit', - 'reduce_sum', - 'reduce_mean', - 'reduce_max', - 'reduce_min', - 'reduce_prod', - 'sequence_first_step', - 'sequence_last_step', - 'dropout', - 'split', - 'ctc_greedy_decoder', - 'edit_distance', - 'l2_normalize', - 'matmul', - 'topk', - 'warpctc', - 'sequence_reshape', - 'transpose', - 'im2sequence', - 'nce', - 'beam_search', - 'row_conv', - 'multiplex', - 'layer_norm', - 'softmax_with_cross_entropy', - 'smooth_l1', - 'one_hot', - 'autoincreased_step_counter', - 'reshape', - 'lod_reset', - 'lrn', - 'pad', - 'label_smooth', - 'roi_pool', - 'dice_loss', - 'bilinear_interp', + 'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru', + 'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy', + 'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d', + 'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d', 'batch_norm', + 'beam_search_decode', 'conv2d_transpose', 'sequence_expand', 'lstm_unit', + 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min', 'reduce_prod', + 'sequence_first_step', 'sequence_last_step', 'dropout', 'split', + 'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'topk', + 'warpctc', 'sequence_reshape', 'transpose', 'im2sequence', 'nce', + 'beam_search', 'row_conv', 'multiplex', 'layer_norm', + 'softmax_with_cross_entropy', 'smooth_l1', 'one_hot', + 'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad', + 'label_smooth', 'roi_pool', 'dice_loss', 'bilinear_interp', 'random_crop' ] @@ -154,7 +109,8 @@ def fc(input, Examples: .. code-block:: python - data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32") + data = fluid.layers.data( + name="data", shape=[32, 32], dtype="float32") fc = fluid.layers.fc(input=data, size=1000, act="tanh") """ @@ -349,7 +305,8 @@ def dynamic_lstm(input, cell_activation(str): The activation for cell output. Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh". candidate_activation(str): The activation for candidate hidden state. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer @@ -516,10 +473,12 @@ def dynamic_lstmp(input, cell_activation(str): The activation for cell output. Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh". candidate_activation(str): The activation for candidate hidden state. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". proj_activation(str): The activation for projection output. - Choices = ["sigmoid", "tanh", "relu", "identity"], + Choices = ["sigmoid", "tanh", + "relu", "identity"], default "tanh". dtype(str): Data type. Choices = ["float32", "float64"], default "float32". name(str|None): A name for this layer(optional). If set None, the layer @@ -2171,7 +2130,8 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_mean(x) # [0.4375] fluid.layers.reduce_mean(x, dim=0) # [0.15, 0.25, 0.55, 0.8] fluid.layers.reduce_mean(x, dim=-1) # [0.475, 0.4] - fluid.layers.reduce_mean(x, dim=1, keep_dim=True) # [[0.475], [0.4]] + fluid.layers.reduce_mean( + x, dim=1, keep_dim=True) # [[0.475], [0.4]] # x is a Tensor variable with shape [2, 2, 2] and elements as below: # [[[1.0, 2.0], [3.0, 4.0]], @@ -2390,7 +2350,8 @@ def split(input, num_or_sections, dim=-1, name=None): x0.shape # [3, 3, 5] x1.shape # [3, 3, 5] x2.shape # [3, 3, 5] - x0, x1, x2 = fluid.layers.split(x, num_or_sections=[2, 3, 4], dim=1) + x0, x1, x2 = fluid.layers.split( + x, num_or_sections=[2, 3, 4], dim=1) x0.shape # [3, 2, 5] x1.shape # [3, 3, 5] x2.shape # [3, 4, 5] @@ -3300,7 +3261,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False): data = fluid.layers.data(name='data', shape=[128], dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') fc = fluid.layers.fc(input=data, size=100) - out = fluid.layers.softmax_with_cross_entropy(logits=fc, label=label) + out = fluid.layers.softmax_with_cross_entropy( + logits=fc, label=label) """ helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_tmp_variable(dtype=logits.dtype) @@ -3347,7 +3309,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): .. code-block:: python data = fluid.layers.data(name='data', shape=[128], dtype='float32') - label = fluid.layers.data(name='label', shape=[100], dtype='float32') + label = fluid.layers.data( + name='label', shape=[100], dtype='float32') fc = fluid.layers.fc(input=data, size=100) out = fluid.layers.smooth_l1(x=fc, y=label) """ @@ -3669,7 +3632,8 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): Examples: .. code-block:: python - data = fluid.layers.data(name="data", shape=[3, 112, 112], dtype="float32") + data = fluid.layers.data( + name="data", shape=[3, 112, 112], dtype="float32") lrn = fluid.layers.lrn(input=data) """ helper = LayerHelper('lrn', **locals()) @@ -3922,10 +3886,10 @@ def bilinear_interp(input, out_h, out_w, name=None): Bilinear interpolation is an extension of linear interpolation for interpolating functions of two variables (e.g. H-direction and W-direction in this layer) on a rectilinear 2D grid. - + For details, please refer to Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation - + Args: input (Variable): The input tensor of bilinear interpolation, This is a 4-D tensor of the shape @@ -3938,7 +3902,7 @@ def bilinear_interp(input, out_h, out_w, name=None): Returns: out (Variable): The output is a 4-D tensor of the shape (num_batches, channls, out_h, out_w). - + Examples: .. code-block:: python @@ -3954,3 +3918,25 @@ def bilinear_interp(input, out_h, out_w, name=None): attrs={"out_h": out_h, "out_w": out_w}) return out + + +def random_crop(input, shape, seed=0): + helper = LayerHelper("random_crop", **locals()) + dtype = helper.input_dtype() + out = helper.create_tmp_variable(dtype) + if isinstance(seed, int): + seed = helper.create_global_variable( + persistable=True, shape=[1], dtype="int32") + helper.set_variable_initializer( + var=seed, initializer=Constant(value=seed)) + elif not isinstance(seed, Variable): + raise ValueError("'seed' must be a Variable or an int.") + seed_out = helper.create_tmp_variable(dtype="int32") + helper.append_op( + type="random_crop", + inputs={"X": input, + "Seed": seed}, + outputs={"Out": out, + "SeedOut": seed_out}, + attrs={"shape": shape}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py new file mode 100644 index 0000000000..e609e2c99f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -0,0 +1,34 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + + +class TestRandomCropOp(OpTest): + def setUp(self): + to_crop = np.random.random((1, 10, 15)).astype("float32") + self.op_type = "random_crop" + self.inputs = {'X': to_crop, 'Seed': np.array([10])} + self.outputs = {'Out': np.array([1, 2, 3]), 'SeedOut': np.array([2])} + self.attrs = {'shape': [5, 5]} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() From d44dbc4a5217fe1b3721824d83351b776f7d64c5 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 28 May 2018 15:42:20 +0800 Subject: [PATCH 03/12] fix errors --- paddle/fluid/operators/random_crop_op.h | 2 +- .../paddle/fluid/tests/unittests/op_test.py | 2 ++ .../tests/unittests/test_random_crop_op.py | 20 +++++++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index 8764bd0bc7..a34294f5ee 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -129,7 +129,7 @@ struct RandomCropFunctor { for (int i = num_batchsize_dims_; i < rank_; ++i) { typename Random::template UniformIntDist dist( 0, x_dims_[i] - out_dims_[i]); - offsets[i] = dist(engine); + offsets[i - num_batchsize_dims_] = dist(engine); } const T* x = x_ + ins_idx * prod_x_ins_dims_; diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 709b4bf2fc..9f9ee271f8 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -336,6 +336,8 @@ class OpTest(unittest.TestCase): actual_t = np.array(actual) expect = self.outputs[out_name] expect_t = expect[0] if isinstance(expect, tuple) else expect + import pdb + pdb.set_trace() self.assertTrue( np.allclose( actual_t, expect_t, atol=atol), diff --git a/python/paddle/fluid/tests/unittests/test_random_crop_op.py b/python/paddle/fluid/tests/unittests/test_random_crop_op.py index e609e2c99f..1c708d0386 100644 --- a/python/paddle/fluid/tests/unittests/test_random_crop_op.py +++ b/python/paddle/fluid/tests/unittests/test_random_crop_op.py @@ -20,14 +20,26 @@ from op_test import OpTest class TestRandomCropOp(OpTest): def setUp(self): - to_crop = np.random.random((1, 10, 15)).astype("float32") + to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] * + 5).astype("float32") + self.possible_res = [ + np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]), + np.array([[5, 6, 7], [9, 10, 11]]), + np.array([[6, 7, 8], [10, 11, 12]]) + ] self.op_type = "random_crop" self.inputs = {'X': to_crop, 'Seed': np.array([10])} - self.outputs = {'Out': np.array([1, 2, 3]), 'SeedOut': np.array([2])} - self.attrs = {'shape': [5, 5]} + self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])} + self.attrs = {'shape': [2, 3]} def test_check_output(self): - self.check_output() + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + out = np.array(outs[1]) + for ins in out[:]: + is_equal = [(ins == res).all() for res in self.possible_res] + self.assertIn(True, is_equal) if __name__ == "__main__": From 20c8ff0f5f85b372ca92d7c81558fbe2a187d1fd Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 28 May 2018 15:58:46 +0800 Subject: [PATCH 04/12] Add comments --- paddle/fluid/operators/random_crop_op.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index b9367f1d22..d92b8bbbb5 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -32,12 +32,18 @@ class RandomCropOp : public framework::OperatorWithKernel { class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", ""); - AddOutput("Out", ""); - AddInput("Seed", ""); - AddOutput("SeedOut", "").AsDispensable(); - AddAttr>("shape", ""); - AddComment(""); + AddInput("X", "A batch of instances to random crop."); + AddInput("Seed", "The random seed."); + AddOutput("Out", "The cropped instance batch."); + AddOutput("SeedOut", "The random seed after random cropping.") + .AsDispensable(); + AddAttr>("shape", "The shape of a cropped instance."); + AddComment(R"DOC( + This operator takes a batch of instance, and do random cropping on each instance. + It means that cropping positions differs on each instance, which is determined + by an uniform random generator. All cropped instances have the same shape, which + is determined by the operator's attribute 'shape'. + )DOC"); } }; From 291f7f8ce5d9b1a7b4d0d4840bee946d9a6d64c5 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 28 May 2018 16:37:53 +0800 Subject: [PATCH 05/12] fix a error --- python/paddle/fluid/tests/unittests/op_test.py | 2 -- tools/codestyle/docstring_checker.pyc | Bin 0 -> 11769 bytes 2 files changed, 2 deletions(-) create mode 100644 tools/codestyle/docstring_checker.pyc diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index b7e62533b3..b611470fa1 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -336,8 +336,6 @@ class OpTest(unittest.TestCase): actual_t = np.array(actual) expect = self.outputs[out_name] expect_t = expect[0] if isinstance(expect, tuple) else expect - import pdb - pdb.set_trace() self.assertTrue( np.allclose( actual_t, expect_t, atol=atol), diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49f07e8e6c0b4d4c7a253c2763e5fb5e7c8c668c GIT binary patch literal 11769 zcmdT~Npl;=7488*kbt;or54N9SSu6>B(->BXxWq~%SvoYHISncEfx%6hUAa~3^+Zo zNn5H^N~tWz+>(Egs#N8a%FUHSeDg8+0m(6^98x(YuH^e(4+bDbJ5t${9Fo)Qe!ahb zdw2F9L&eokcf4O$$xjyl-^8bXg~U~A4ymQwj0!9@hm6T*)LcgLSrueeBd6was?n$B z`jo{MIThqpqhHPSn{uBD22`V<<_e~a(TZxWsNB5zF{ap|+sfwcXa9YsaBoC&+HC$4g<;cEh@D zTAtF#tOwqh{~!F!mNFP7=75)L+DnFs`%rGjexOSht4EMrpGN=CcatWjg8f`?q1Fy! z*RRKQw&q7XUcsmLBe}2CG6u1jd1&#EXsXD}dN&$4-lDH#tX`JvMS3h*)=pGLNo6&u zj0xPlZ?l>_Y68EBkxIxZN%T$R#^2Kzb^Kn1Wn_>oIzY?0b@Yv00!+W>=5#rSX7y~fHYlOoiP5e@-jYs2Z@Ix z$4Mg<8X^T`d8*9hl#6^${TS5Nr`$dP<~xMaPeNg{frL;%P#_XQipr^!ikxB+NxA~N z9@Mn9m%_l+cD&@-ty)Ctz@8Fnu)V+&WQ-MU{GSWoNxd7*h0|#*kDb)k77MU|*t@eV2r*T5ns^7#BSC4T{POPjEE8QX& zdxbUANESO}3d)@77BansY;Pf_rhMsBmBXx*$TG9C`RQm;3x!-=t2J5y=%^ocju*G1 zrjALZPR-Yz76x#5YSOKSZh|9c5TI?wk3-Hqg->&SWep|2f+nE^qKA7*DS>DLb{=F& z5BD?V;|DplR#Nvt7&&DFw31QLR8l_j0AQi)SI}j3pSq1wrNm)pfjOmo(uW}m0t2rp zIF#KBmJHiT)^EDtSfQoB`_K!7(b_y*?SJfTGp=N^MK7+()MHqy<_Re)Zw*>|tfDoX zsS`^)Wm1_C)p_jB2)>7C2^1|u@+tKIOg}^%uJu_!*v$MFL?Y0)41}V(PsLevzmLd; zC|I#n^sZVxrkZ9jX-T~2*=89cBFC1|$CMTK^J*c(q0v66%}ebsDL1yL9gy1JQ-*9& zTa?;AD|g4nT1<^o9aINbaw^JL+@k?y)xm?ym`E0HnMJ$9T+mmvxy?E(1ULekDdLFfe%RFZLKHSAyF^rxNGt97l3mj_Iqs~Ta2PVeUT5tN|Jge2!IqP0 ziDOU}mJF$#o}~VyoU9pLXE&h+dZeWhlI)etwQ;ChQQNcq1?hl+5@D;`^}46^LOTf7 zy8{SWhnOqUqKJhhG#O5o=nNAPR5P8K9O^wSl0>#Q>Fgqa8mF_*Nr&h&d_{tE4zfr! zZ$wSENTwGN5R1}i1%3?02}#8T5E(6J51Sl7Vw%LYDAphO@sbf(lw!43%WJwZS6YSe zsWz3qTFP7KFNZ2vLp<6zVY3WyoCETw;PPLk$>+NdKwTVTC`Dq&bLCPBUxd z2Nitj{ps>tt^Dys`L$|!{>4lz`lMweIyivga>77FWOfY; zuD&GfyBLM-mP5l`w7F5(ROZIyYiu+7Ixk=c)oQcW@Tygj zT-9nLblU-olt$Gmq)@f$3?VNPP-L8Qg2j_eh%<*+>*FSyOeucfV-9Z9m(ub?itDKRqk2I>{D)8GNW{p#)-VENM=H9^nT*fQu(*m z=|FZIQMl1!D~Rh!wQkxc4zyHh5)Hs*y4~flS#E{er@2xK$_ruCsKsUIM&ocE+thT$ zTZjAUKGKe?g~`fhwsOZSCb@UPqSv(z*o~%VQ+soP9k&iH|0VmJfVJ2NSejQB9e)m8 z;19d@M@wGQuGxAijN;zDKm)or;M4o?Ud$N6__tr1m^i1&AT!;W+|bjolVTgTfK3y? zK00>PZh4U(x@Klnlv$xS&)l=#M@Ok8W27^m*RNsb=^(IyP;I7y{p3eBZ4QkE>)5{M z*>q=te{9h2jeBU?DYA}oEeLohTA$y53T^^b9NwH4auZe& z3^ACHM;P@KXi?J543 zIm<{|368d!{>#TmXMa_kw=CAr?+QNsBP7P0qb5eY%J^?gGwJ~3>k6E%Q#5KTgyodJ zCm|y^<>=3leCR`@R2SM!@q-X= zGxAB>ZzZ7APf2LR1o%=nXZjJRNBo*n;^mC71m6evvt=q$GhtnCB_V>X-JPz3V~axr z)&!xwRSR^_2tK9Pti-4!f+s>vbg@~pDDF6Yb`7kv-M@71eso)bWu$cv*xT>Uj zVS%NZt{P1$q-_LxRk-`YUtyESXQZeo=SOd#a zsg(hxFDXb%1A&C3Oogv8FX)BY3O7iw=fc-PFQPGHHG#VRj)x-x#kq(Hh18RJWJvj{vAu|rOb|UO6?)sIGK85H-Cd_iIKQu z=n~BKM2968Qk|M`Vbi?dfHK3DmGcr4idp9^B(o)k;*6HM#+)cEUJe<0Bd!AKx|Z`c zfkK?uk?RJ@%wYRiYcqL~$uT4%4C?I&dXDB@37~Bh>kgfxPw^$%X+-qWU}nfV49#@d z8ZhB*3U+QP%a-d%3|)dke+ECLDAajSs1tS)Kna&4YCLYBIrUR$dGLNl{RBOBs2Mb2 ztGrw!yk{sIaujiHIIRPmfU#=G0|vG;7KMS(A@z{{Ba)8v?9iYI#wJ!#30DmElznmH zG(A1IIA|^2P1nwO0!v9usE)bR9pQ)_-*vwm;RsvKqxa_!#fNj~k za|KiW317lJ1%#AyncY^-I*R{?@ja?Gq1ZZ79kicHb!^J;6g-TFy$p1|N9cn)PzR0x z8-X13vWRc+3#dYrC$EOUGQ8~p%ixklScV!!{{czY#^_<14rfSMpyR*Wr^Zg%T}ppU zUV)2l+!OHJcCd@kJCiYo?tCv1(TgyWbR)4iB9)&*t#}>SpSjt1o9(-Dc`MpMi|t_< zX9gqxjid9PNA*+aK8NoB1TIL1{6GtZ0+$4)e~2U&mA^J2=W9XJpP<^u@j~r`c1SfD z-fyh#{mt&=5j|o6N9+p&>XP>?@IGP4(a7 zVm%8Yq8R{d!!D&!5?BIh-VX0A;|8%yxL;pHW6R3n0fKsJqbV8po{Q`MyyS_rZi>(MVN^^r!?5T3^$Y5Z1N71vq+32LuQ+(@pX*+Z;s9z zKvl>?@$+6RZ}s8(KcLEgVb#5nfF80;?b0-yg-B=(p+Cf~OxysW!!EUaBVz8^Q0U7* z!d4XOIIO5JxrIdTcMMbXP~%Bfe@m!w3TqjKjQ`)%NVUdKWaMsiSu`>dw{6gY(37w* z5x$>aAIc8zn*%)cQI zKTr$TuU|!<({Z~7)A;d5S^XTV#u~+j(*FdKHigf6cmzSUo)Gxvi29Rya9?c#JQi=H zQ+{pXE!Loq-78J0>w5Rvi<8My4QyLa+Rg>QyL=;d9AcWd*(mX*;WW&d7|x|#qT3<& zpc0+o^>2O258TLW_AnQXkZIfyEH-CxDJjZQ z39kv92`03z%-fNw&sTDCv4Bfjy1zGXRIcB;eq(m}o$9++E0yVMSF6`&Bx?5kt9NeS zaAtbI&K(X!yGXA$?uCJvbS>OIMSi?4=x%7d$D>dPTk?*`c>{UpO(xS!7}awwGoj-v z_PbbP4waxo+2e5MU6m`j?QtGx Date: Mon, 28 May 2018 16:42:17 +0800 Subject: [PATCH 06/12] delete tmp file --- tools/codestyle/docstring_checker.pyc | Bin 11769 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tools/codestyle/docstring_checker.pyc diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc deleted file mode 100644 index 49f07e8e6c0b4d4c7a253c2763e5fb5e7c8c668c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 11769 zcmdT~Npl;=7488*kbt;or54N9SSu6>B(->BXxWq~%SvoYHISncEfx%6hUAa~3^+Zo zNn5H^N~tWz+>(Egs#N8a%FUHSeDg8+0m(6^98x(YuH^e(4+bDbJ5t${9Fo)Qe!ahb zdw2F9L&eokcf4O$$xjyl-^8bXg~U~A4ymQwj0!9@hm6T*)LcgLSrueeBd6was?n$B z`jo{MIThqpqhHPSn{uBD22`V<<_e~a(TZxWsNB5zF{ap|+sfwcXa9YsaBoC&+HC$4g<;cEh@D zTAtF#tOwqh{~!F!mNFP7=75)L+DnFs`%rGjexOSht4EMrpGN=CcatWjg8f`?q1Fy! z*RRKQw&q7XUcsmLBe}2CG6u1jd1&#EXsXD}dN&$4-lDH#tX`JvMS3h*)=pGLNo6&u zj0xPlZ?l>_Y68EBkxIxZN%T$R#^2Kzb^Kn1Wn_>oIzY?0b@Yv00!+W>=5#rSX7y~fHYlOoiP5e@-jYs2Z@Ix z$4Mg<8X^T`d8*9hl#6^${TS5Nr`$dP<~xMaPeNg{frL;%P#_XQipr^!ikxB+NxA~N z9@Mn9m%_l+cD&@-ty)Ctz@8Fnu)V+&WQ-MU{GSWoNxd7*h0|#*kDb)k77MU|*t@eV2r*T5ns^7#BSC4T{POPjEE8QX& zdxbUANESO}3d)@77BansY;Pf_rhMsBmBXx*$TG9C`RQm;3x!-=t2J5y=%^ocju*G1 zrjALZPR-Yz76x#5YSOKSZh|9c5TI?wk3-Hqg->&SWep|2f+nE^qKA7*DS>DLb{=F& z5BD?V;|DplR#Nvt7&&DFw31QLR8l_j0AQi)SI}j3pSq1wrNm)pfjOmo(uW}m0t2rp zIF#KBmJHiT)^EDtSfQoB`_K!7(b_y*?SJfTGp=N^MK7+()MHqy<_Re)Zw*>|tfDoX zsS`^)Wm1_C)p_jB2)>7C2^1|u@+tKIOg}^%uJu_!*v$MFL?Y0)41}V(PsLevzmLd; zC|I#n^sZVxrkZ9jX-T~2*=89cBFC1|$CMTK^J*c(q0v66%}ebsDL1yL9gy1JQ-*9& zTa?;AD|g4nT1<^o9aINbaw^JL+@k?y)xm?ym`E0HnMJ$9T+mmvxy?E(1ULekDdLFfe%RFZLKHSAyF^rxNGt97l3mj_Iqs~Ta2PVeUT5tN|Jge2!IqP0 ziDOU}mJF$#o}~VyoU9pLXE&h+dZeWhlI)etwQ;ChQQNcq1?hl+5@D;`^}46^LOTf7 zy8{SWhnOqUqKJhhG#O5o=nNAPR5P8K9O^wSl0>#Q>Fgqa8mF_*Nr&h&d_{tE4zfr! zZ$wSENTwGN5R1}i1%3?02}#8T5E(6J51Sl7Vw%LYDAphO@sbf(lw!43%WJwZS6YSe zsWz3qTFP7KFNZ2vLp<6zVY3WyoCETw;PPLk$>+NdKwTVTC`Dq&bLCPBUxd z2Nitj{ps>tt^Dys`L$|!{>4lz`lMweIyivga>77FWOfY; zuD&GfyBLM-mP5l`w7F5(ROZIyYiu+7Ixk=c)oQcW@Tygj zT-9nLblU-olt$Gmq)@f$3?VNPP-L8Qg2j_eh%<*+>*FSyOeucfV-9Z9m(ub?itDKRqk2I>{D)8GNW{p#)-VENM=H9^nT*fQu(*m z=|FZIQMl1!D~Rh!wQkxc4zyHh5)Hs*y4~flS#E{er@2xK$_ruCsKsUIM&ocE+thT$ zTZjAUKGKe?g~`fhwsOZSCb@UPqSv(z*o~%VQ+soP9k&iH|0VmJfVJ2NSejQB9e)m8 z;19d@M@wGQuGxAijN;zDKm)or;M4o?Ud$N6__tr1m^i1&AT!;W+|bjolVTgTfK3y? zK00>PZh4U(x@Klnlv$xS&)l=#M@Ok8W27^m*RNsb=^(IyP;I7y{p3eBZ4QkE>)5{M z*>q=te{9h2jeBU?DYA}oEeLohTA$y53T^^b9NwH4auZe& z3^ACHM;P@KXi?J543 zIm<{|368d!{>#TmXMa_kw=CAr?+QNsBP7P0qb5eY%J^?gGwJ~3>k6E%Q#5KTgyodJ zCm|y^<>=3leCR`@R2SM!@q-X= zGxAB>ZzZ7APf2LR1o%=nXZjJRNBo*n;^mC71m6evvt=q$GhtnCB_V>X-JPz3V~axr z)&!xwRSR^_2tK9Pti-4!f+s>vbg@~pDDF6Yb`7kv-M@71eso)bWu$cv*xT>Uj zVS%NZt{P1$q-_LxRk-`YUtyESXQZeo=SOd#a zsg(hxFDXb%1A&C3Oogv8FX)BY3O7iw=fc-PFQPGHHG#VRj)x-x#kq(Hh18RJWJvj{vAu|rOb|UO6?)sIGK85H-Cd_iIKQu z=n~BKM2968Qk|M`Vbi?dfHK3DmGcr4idp9^B(o)k;*6HM#+)cEUJe<0Bd!AKx|Z`c zfkK?uk?RJ@%wYRiYcqL~$uT4%4C?I&dXDB@37~Bh>kgfxPw^$%X+-qWU}nfV49#@d z8ZhB*3U+QP%a-d%3|)dke+ECLDAajSs1tS)Kna&4YCLYBIrUR$dGLNl{RBOBs2Mb2 ztGrw!yk{sIaujiHIIRPmfU#=G0|vG;7KMS(A@z{{Ba)8v?9iYI#wJ!#30DmElznmH zG(A1IIA|^2P1nwO0!v9usE)bR9pQ)_-*vwm;RsvKqxa_!#fNj~k za|KiW317lJ1%#AyncY^-I*R{?@ja?Gq1ZZ79kicHb!^J;6g-TFy$p1|N9cn)PzR0x z8-X13vWRc+3#dYrC$EOUGQ8~p%ixklScV!!{{czY#^_<14rfSMpyR*Wr^Zg%T}ppU zUV)2l+!OHJcCd@kJCiYo?tCv1(TgyWbR)4iB9)&*t#}>SpSjt1o9(-Dc`MpMi|t_< zX9gqxjid9PNA*+aK8NoB1TIL1{6GtZ0+$4)e~2U&mA^J2=W9XJpP<^u@j~r`c1SfD z-fyh#{mt&=5j|o6N9+p&>XP>?@IGP4(a7 zVm%8Yq8R{d!!D&!5?BIh-VX0A;|8%yxL;pHW6R3n0fKsJqbV8po{Q`MyyS_rZi>(MVN^^r!?5T3^$Y5Z1N71vq+32LuQ+(@pX*+Z;s9z zKvl>?@$+6RZ}s8(KcLEgVb#5nfF80;?b0-yg-B=(p+Cf~OxysW!!EUaBVz8^Q0U7* z!d4XOIIO5JxrIdTcMMbXP~%Bfe@m!w3TqjKjQ`)%NVUdKWaMsiSu`>dw{6gY(37w* z5x$>aAIc8zn*%)cQI zKTr$TuU|!<({Z~7)A;d5S^XTV#u~+j(*FdKHigf6cmzSUo)Gxvi29Rya9?c#JQi=H zQ+{pXE!Loq-78J0>w5Rvi<8My4QyLa+Rg>QyL=;d9AcWd*(mX*;WW&d7|x|#qT3<& zpc0+o^>2O258TLW_AnQXkZIfyEH-CxDJjZQ z39kv92`03z%-fNw&sTDCv4Bfjy1zGXRIcB;eq(m}o$9++E0yVMSF6`&Bx?5kt9NeS zaAtbI&K(X!yGXA$?uCJvbS>OIMSi?4=x%7d$D>dPTk?*`c>{UpO(xS!7}awwGoj-v z_PbbP4waxo+2e5MU6m`j?QtGx Date: Tue, 29 May 2018 14:56:37 +0800 Subject: [PATCH 07/12] fix a bug --- python/paddle/fluid/layers/nn.py | 5 +++-- tools/codestyle/docstring_checker.pyc | Bin 0 -> 11769 bytes 2 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 tools/codestyle/docstring_checker.pyc diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c337e0f4f2..3f04dcccd6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3990,15 +3990,16 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): return out -def random_crop(input, shape, seed=0): +def random_crop(input, shape, seed=1): helper = LayerHelper("random_crop", **locals()) dtype = helper.input_dtype() out = helper.create_tmp_variable(dtype) if isinstance(seed, int): + seed_value = seed seed = helper.create_global_variable( persistable=True, shape=[1], dtype="int32") helper.set_variable_initializer( - var=seed, initializer=Constant(value=seed)) + var=seed, initializer=Constant(value=seed_value)) elif not isinstance(seed, Variable): raise ValueError("'seed' must be a Variable or an int.") seed_out = helper.create_tmp_variable(dtype="int32") diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce612ca2318ccb9b9f28d51cb93ce8e5e1d0680 GIT binary patch literal 11769 zcmdT~Npl;=7488*kbt;or54N9SSu6>B(->BXxWq~%SvoYHISncEfx%6hUAa~3^+Zo zNn5H^N~tWz+>(Egs#N8a$}N>ceDg86+w?9wB4|7 zo0g|EGV6gi=Klvjv!x7%i8N5EO`nkfL%br6Q-8M3Sz+ zt_L-(?WHhqwH+^ccB>YVI;Ul&m$Ad>EYaPfYpqgXO>iI zxeCX~^PVl`kUCj5W-h>a;=n;1LQ3_(t3~FFSgXBY;c1*uwdyx<#MNV*lM^dz#7eiw z#a>~}G?K**nSwH>x`j+{A=_KXsVQIjROK)$C9=${Y<@ag)IuRw*J_Pc06OYNo#Vyr zsHtNTsZ;Z{r-cC=o|<&4p_|~y83bsX@#BzlPvO&?Us*%Rub@dNf#~6$Qc572fSm_f z(!>1>`S?Ljt(DY$5Jpbf0Ig(HG?kQ(JOEfI`xSIq-KTD&R4H-TSzt~npY&mfg22FQ z3JzuWf+fRtlJ%P|I96yW@ILecVYD_6SNk7(+l(ujY|)FWGW8hNs(C`n%3Fih9;;{# zXX?ZfPnlFEM0Fm!GlK6SS^`DOkbFu#0MicUaIgtqTEd!yb?o)AA-R~oE zAqrM36}_ugkEx~^Oj;7}dA3=Gh{&;J^f6_{{k&SpaA>qoYV%V2bIOe^Y6qnDx0E4U z)E1@oPs-h~u@+O~R0q|;m7Iz)7WZgCS#|K>GA5G6TV~PjFcB$}SOAB+^Q|zGT<5O^!QjCme>1u-93;!+*9;d$8qX zTH+X#g(X93rzfdDDJN@2*V#>|fgWjTgd}?2~CERB|5`I1l3GuCWm@YizJclO**>>pvLL!bJ8LD3}2BTor5e= z%^OkEEt2U)1jM2=T7e%!aY9mY0Ypa2*~2CWkeDWMEsFI=e!OG^7NuCN)$*Ed%#~Il zysE6N#_KIl#2JxMlI4Sbl1(P?!mH8sMT9 zl#gOwc_{X){ct!JKv@O#vRf3`4#41w0yYHLu)szD+bOVJfbAC89>DerY#(5w0^1MR z0f8L^?2y2o1MINCo(Jp&fgJ(N7TAmGW%|Km;shTRPi9cL$0P#>__$=?^PZ5*h;mO# zW~XvbDYqnU(k|tm7T|8>eoHcYlzT=pdzE`uGW(QUmdq&Kq;VqeDw3H{8@->nv{e4B zbvlq8M-*lL7%|@@A_IJ z0z(WYUAbX8Xqscwu6W}~#~^GjA`Eir^>Y*HBrl190)!`Rj4Nb& zQ4~gJ8SJqy+7iI9L2g(r9VJcaRyO&vwA8e#aOzt$PG@>qA4j7*VVjF>2d<5Pi2Xrg ze||tj)We1Z3(WXD({f^CXmZspShHTIGftw`(>pQZdq+tpens&3$i!j;e9$rfdwYui zWzI5^R)V9grvKtG(%D}X=Pir%^Sgpi{{V?G=ctJhuQL7{(~LR*`MLt9>lBUJ3Sl{= zFG>MpJ`^91`War4)JcDMIdF}jzTxI&FrIil8_V#*(t3&VaXI=^Bp><^DbyX;Mn5O zfHgsAZ`A_bGlEa)H7hYHiQtJ)6J2c9EQ&i0DWSlJY)o&@6dfu@$7VuRDau-6lHgW) z9vX2T4i(c16E7s8O0k3uKnS77UZEQ09%b0BmfscyAJd$EIH>Sj4xTVNgzF_GN6WU!vUcG&+m_ zMsI*R!E8VxrC_aJec2X>;tglO0+==t5-1qkoe~2DL_SPS&ry-?Vk-DV3v6ZqD6T50 zURYqMrmIGi3TYdGUKQ?sx+g8f^=qPr4A==5rbJ2yL5Q@54EDTiQY)2}hg6Jv3)aB0 zRBB~F=}QXI(m)^~DO2HV%nN#9w!#e(?78rD(2HoySWTd=zvJPEKyfZ&LLv1e?A-K9 zT%ki#!C@@CD?6mcEbuF#4G|4caMcE4Gc~wwG_plzJ!Q;Z#=O75mjq%d2##AvET4o@ zrwX9g!cP$VQ)CQ+Va)@L(Vi5uP&`0cmMDIRJVLW`3@pyfI}%m|A|)~hG??fh8UgVK z7)!)oMvd|evML(?f-~rsGk~0~td6VOJkOFihJVMBdMUG`oKkxTH%_MB*v(&~T4E$F z8M*|sJ<(ywg;b{|T-Y@4*PzU>W#zoYgksit3(0JWp*W+Zt}!P{izB1GBemd*4j*7WO59N2!nb%f}W#!R|04o#kxc1=y&)M?KC2KX)rTn9foE) zYz>%jHw8O4m1WCyB!(_Qp+AA2QWWYuDAWl%37~|_5j7q+(46`)v^;n}qke=QJJbxC zuvK0z65cbE4LOQ9H=Nc1PQX|-U|=#Y9y{{cxydUj~g1Y;AcsDvv9d&<5z zahjeUTpY9(@1|?#Jb|SoCRE2<>W=V4#I>cN#M7-)qZ7eI?gf0BK@aGZ`MnGj0h>m| z0$bm+Rz;?EF5=izMBM_2MPwOJq}14i{A&g#I%WgJiPW|jJmU9FEqYCFtrcDRE2lBk zqPe=Fq(cHK4*8!WKJCyo?ByjPsz)0e;cl~Wm&f*?DDh?8Zfxc)PAdT_8C-VUMZmV~ z-MNA(|A;T)o&rM3xy)`WXC1}=!}uOmn^0^WsSestr8+iccnTiI!(IkD-y`(F9jF6G zfQ>*7dRfFb_ytrU%9B?^U>V-_fMsyWA}m9VqJNL1Yh&~@KB0 zCa=InH|_~|ZadgT=$*-!LwCNHi0DNaNxG3(9FfY;p;o*O?9be6yv_Dqxx5wapvCsE zj5C9g|HjdI&!hS&bf3d_00I{zLw=xzLV-&H)89uDi^^Y{kn^>m=}%DY<9MO=K|7?H z4DUBqck%)e77h%An<(%mR$n2Yg4iy7Rp|DVz7TT$L}&;h!cAsa!O##K5B8Nvv!?p* zaj~8S5z!2QwPBaiCz;(4YtTh_L?hOwTvHk|fg;RAo>Lm^UWS`VY&Lm^$yp>uk|DEA)c86^{x?VG z4WKIIq4;?(mbd!w{U1=}zp(1wNI(x+rgmu>&O#)#hR`2kS0-+N&|#Naz7{d}Y$)_) zAYm&CbsSbynA}1l_dA9udZ_UvtG^-CIEA&0LdO4ZYNT4@Co*z3x-1$QiQ6{lK8<00ZTf3N)3GtuGqD=87sh3F5KLRs)qf`#qt3L6!s2PS-Zx$X6E0J zhaaeg>({TM&*->agK7MDqpW_0Rb!1}L+O74Nt?oFJv@S-T2Bakb42|~J-9D70UnDt z(kZ{P@D^*($L^J;)OEdk?ZwIDsRp*KCvE2f;9b5EI}R~T+-#J1({LK*ObqAJF466f zdr*na@cOsD zrG9IOm&g0%+Bjb*42~8?3wsKO3WwxfmEhetO8%OW{3DpfTzQUlZW*Z?l!H@N61Gli z Date: Wed, 30 May 2018 10:48:57 +0800 Subject: [PATCH 08/12] Add .cu --- paddle/fluid/operators/random_crop_op.cc | 2 +- paddle/fluid/operators/random_crop_op.cu | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/random_crop_op.cu diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index d92b8bbbb5..b14b559e31 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -74,8 +74,8 @@ namespace ops = paddle::operators; namespace f = paddle::framework; REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, ops::RandomCropOpInferShape, f::EmptyGradOpMaker); + template using Kernel = ops::RandomCropKernel; - REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.cu b/paddle/fluid/operators/random_crop_op.cu new file mode 100644 index 0000000000..2782911b4f --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/random_crop_op.h" + +template +using Kernel = ops::RandomCropKernel; +REGISTER_OP_CUDA_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); From 7c42e5de1a4db1662db19a1b62c63ce98de713a6 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 30 May 2018 13:37:41 +0800 Subject: [PATCH 09/12] Polish RandomCropOp --- paddle/fluid/operators/random_crop_op.cc | 2 +- paddle/fluid/operators/random_crop_op.cu | 20 +++++++++++++++++++ python/paddle/fluid/layers/nn.py | 5 +++-- .../paddle/fluid/tests/unittests/op_test.py | 2 -- 4 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/random_crop_op.cu diff --git a/paddle/fluid/operators/random_crop_op.cc b/paddle/fluid/operators/random_crop_op.cc index d92b8bbbb5..b14b559e31 100644 --- a/paddle/fluid/operators/random_crop_op.cc +++ b/paddle/fluid/operators/random_crop_op.cc @@ -74,8 +74,8 @@ namespace ops = paddle::operators; namespace f = paddle::framework; REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker, ops::RandomCropOpInferShape, f::EmptyGradOpMaker); + template using Kernel = ops::RandomCropKernel; - REGISTER_OP_CPU_KERNEL(random_crop, Kernel, Kernel, Kernel, Kernel, Kernel); diff --git a/paddle/fluid/operators/random_crop_op.cu b/paddle/fluid/operators/random_crop_op.cu new file mode 100644 index 0000000000..2782911b4f --- /dev/null +++ b/paddle/fluid/operators/random_crop_op.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/random_crop_op.h" + +template +using Kernel = ops::RandomCropKernel; +REGISTER_OP_CUDA_KERNEL(random_crop, Kernel, Kernel, Kernel, + Kernel, Kernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c337e0f4f2..3f04dcccd6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3990,15 +3990,16 @@ def upsampling_bilinear2d(input, out_shape=None, scale=None, name=None): return out -def random_crop(input, shape, seed=0): +def random_crop(input, shape, seed=1): helper = LayerHelper("random_crop", **locals()) dtype = helper.input_dtype() out = helper.create_tmp_variable(dtype) if isinstance(seed, int): + seed_value = seed seed = helper.create_global_variable( persistable=True, shape=[1], dtype="int32") helper.set_variable_initializer( - var=seed, initializer=Constant(value=seed)) + var=seed, initializer=Constant(value=seed_value)) elif not isinstance(seed, Variable): raise ValueError("'seed' must be a Variable or an int.") seed_out = helper.create_tmp_variable(dtype="int32") diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index b7e62533b3..b611470fa1 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -336,8 +336,6 @@ class OpTest(unittest.TestCase): actual_t = np.array(actual) expect = self.outputs[out_name] expect_t = expect[0] if isinstance(expect, tuple) else expect - import pdb - pdb.set_trace() self.assertTrue( np.allclose( actual_t, expect_t, atol=atol), From 45530c772e484d5033ceee90a034278f09ada6ba Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Wed, 30 May 2018 13:45:09 +0800 Subject: [PATCH 10/12] Fix GPU compile --- paddle/fluid/operators/random_crop_op.cu | 1 + paddle/fluid/operators/random_crop_op.h | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/random_crop_op.cu b/paddle/fluid/operators/random_crop_op.cu index 2782911b4f..6fc9bedc55 100644 --- a/paddle/fluid/operators/random_crop_op.cu +++ b/paddle/fluid/operators/random_crop_op.cu @@ -14,6 +14,7 @@ #include "paddle/fluid/operators/random_crop_op.h" +namespace ops = paddle::operators; template using Kernel = ops::RandomCropKernel; REGISTER_OP_CUDA_KERNEL(random_crop, Kernel, Kernel, Kernel, diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index a34294f5ee..e0e24a7d1f 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -60,7 +60,7 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, size_t offset_i = offsets[i]; if (i == rank - 1) { - PADDLE_ENFORCE(x_stride == 1 && out_stride == 1); + PADDLE_ASSERT(x_stride == 1 && out_stride == 1); x += offset_i; for (size_t j = 0; j < out_dim_i; ++j) { *out++ = *x++; @@ -105,12 +105,12 @@ struct RandomCropFunctor { prod_batchsize_dims_ = 1; prod_x_ins_dims_ = 1; prod_out_ins_dims_ = 1; - for (size_t i = 0; i < rank_; ++i) { + for (size_t i = 0; i < static_cast(rank_); ++i) { size_t x_dim_i = x_dims[i]; size_t out_dim_i = out_dims[i]; x_dims_[i] = x_dim_i; out_dims_[i] = out_dim_i; - if (i < num_batchsize_dims_) { + if (i < static_cast(num_batchsize_dims_)) { PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i); prod_batchsize_dims_ *= x_dim_i; } else { From 3bce3dbce14a0ec4f65f54ed73cd268e3f5964ce Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Wed, 30 May 2018 15:13:38 +0800 Subject: [PATCH 11/12] fix a bug --- paddle/fluid/framework/operator.cc | 1 + python/paddle/fluid/layers/nn.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d70f26026c..30f784598a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -466,6 +466,7 @@ class RuntimeInferShapeContext : public InferShapeContext { protected: DDim GetDim(const std::string& name) const override { Variable* var = scope_.FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var); if (var->IsType()) { return var->Get().dims(); } else if (var->IsType()) { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3f04dcccd6..ec95efd699 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3996,13 +3996,19 @@ def random_crop(input, shape, seed=1): out = helper.create_tmp_variable(dtype) if isinstance(seed, int): seed_value = seed - seed = helper.create_global_variable( - persistable=True, shape=[1], dtype="int32") - helper.set_variable_initializer( - var=seed, initializer=Constant(value=seed_value)) + seed = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="fill_constant", + inputs={}, + outputs={"Out": seed}, + attrs={ + "dtype": seed.dtype, + "shape": [1], + "value": float(seed_value) + }) elif not isinstance(seed, Variable): raise ValueError("'seed' must be a Variable or an int.") - seed_out = helper.create_tmp_variable(dtype="int32") + seed_out = helper.create_tmp_variable(dtype="int64") helper.append_op( type="random_crop", inputs={"X": input, From a6c11a5d95e9f1b62589a42305e5a9b97a4194f5 Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Wed, 30 May 2018 16:22:41 +0800 Subject: [PATCH 12/12] Fix bug in CUDA --- paddle/fluid/operators/random_crop_op.h | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h index e0e24a7d1f..f3261cbdc9 100644 --- a/paddle/fluid/operators/random_crop_op.h +++ b/paddle/fluid/operators/random_crop_op.h @@ -67,7 +67,7 @@ HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out, } } else { x += offset_i * x_stride; - for (size_t j = 0; j < x_dim_i; ++j) { + for (size_t j = 0; j < out_dim_i; ++j) { StridedMemcpy(x, x_dims, out, out_dims, i + 1, rank, x_stride, out_stride, offsets); x += x_stride; @@ -86,8 +86,6 @@ struct RandomCropFunctor { int rank_; int64_t seed_; - size_t prod_x_dims_; - size_t prod_out_dims_; size_t prod_batchsize_dims_; size_t prod_x_ins_dims_; size_t prod_out_ins_dims_; @@ -118,8 +116,6 @@ struct RandomCropFunctor { prod_out_ins_dims_ *= out_dim_i; } } - prod_x_dims_ = prod_batchsize_dims_ * prod_x_ins_dims_; - prod_out_dims_ = prod_batchsize_dims_ * prod_out_ins_dims_; } HOSTDEVICE void operator()(size_t ins_idx) { @@ -146,7 +142,17 @@ template class RandomCropKernel : public framework::OpKernel { public: virtual void Compute(const framework::ExecutionContext& ctx) const { - int64_t seed = *ctx.Input("Seed")->data(); + auto& seed_tensor = detail::Ref(ctx.Input("Seed")); + int64_t seed = 0; + if (platform::is_cpu_place(seed_tensor.place())) { + seed = *seed_tensor.data(); + } else { + LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify " + "your program"; + framework::LoDTensor cpu_seed; + framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed); + seed = *cpu_seed.data(); + } auto shape = ctx.Attr>("shape"); auto& x = detail::Ref(ctx.Input("X")); auto& out = detail::Ref(ctx.Output("Out"));