Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-compile-by-std-move
commit
96be582ef3
@ -0,0 +1,81 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
#include "paddle/fluid/operators/random_crop_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class RandomCropOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X", "A batch of instances to random crop.");
|
||||
AddInput("Seed", "The random seed.");
|
||||
AddOutput("Out", "The cropped instance batch.");
|
||||
AddOutput("SeedOut", "The random seed after random cropping.")
|
||||
.AsDispensable();
|
||||
AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
|
||||
AddComment(R"DOC(
|
||||
This operator takes a batch of instance, and do random cropping on each instance.
|
||||
It means that cropping positions differs on each instance, which is determined
|
||||
by an uniform random generator. All cropped instances have the same shape, which
|
||||
is determined by the operator's attribute 'shape'.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class RandomCropOpInferShape : public framework::InferShapeBase {
|
||||
public:
|
||||
void operator()(framework::InferShapeContext* ctx) const override {
|
||||
auto seed_dim = ctx->GetInputDim("Seed");
|
||||
PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
|
||||
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
||||
auto x_dim = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
|
||||
auto out_dim = framework::vectorize2int(x_dim);
|
||||
for (size_t i = 1; i <= shape.size(); ++i) {
|
||||
size_t x_i = x_dim.size() - i;
|
||||
size_t shape_i = shape.size() - i;
|
||||
PADDLE_ENFORCE_GE(x_dim[x_i], shape[shape_i]);
|
||||
out_dim[x_i] = shape[shape_i];
|
||||
}
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
|
||||
ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
namespace f = paddle::framework;
|
||||
REGISTER_OPERATOR(random_crop, ops::RandomCropOp, ops::RandomCropOpMaker,
|
||||
ops::RandomCropOpInferShape, f::EmptyGradOpMaker);
|
||||
|
||||
template <typename T>
|
||||
using Kernel = ops::RandomCropKernel<paddle::platform::CPUDeviceContext, T>;
|
||||
REGISTER_OP_CPU_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
|
||||
Kernel<uint8_t>, Kernel<int16_t>);
|
@ -0,0 +1,21 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/random_crop_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
template <typename T>
|
||||
using Kernel = ops::RandomCropKernel<paddle::platform::CUDADeviceContext, T>;
|
||||
REGISTER_OP_CUDA_KERNEL(random_crop, Kernel<float>, Kernel<int>, Kernel<double>,
|
||||
Kernel<uint8_t>, Kernel<int16_t>);
|
@ -0,0 +1,181 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/detail/safe_ref.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/for_range.h"
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include <thrust/random.h>
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext>
|
||||
struct Random;
|
||||
|
||||
template <>
|
||||
struct Random<platform::CPUDeviceContext> {
|
||||
using Engine = std::minstd_rand;
|
||||
|
||||
template <typename T>
|
||||
using UniformIntDist = std::uniform_int_distribution<T>;
|
||||
};
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
template <>
|
||||
struct Random<platform::CUDADeviceContext> {
|
||||
using Engine = thrust::minstd_rand;
|
||||
|
||||
template <typename T>
|
||||
using UniformIntDist = thrust::uniform_int_distribution<T>;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
HOSTDEVICE inline void StridedMemcpy(const T* x, const size_t* x_dims, T* out,
|
||||
const size_t* out_dims, int i, int rank,
|
||||
size_t prod_x_remain,
|
||||
size_t prod_out_remain,
|
||||
const size_t* offsets) {
|
||||
size_t x_dim_i = x_dims[i];
|
||||
size_t out_dim_i = out_dims[i];
|
||||
size_t x_stride = prod_x_remain / x_dim_i;
|
||||
size_t out_stride = prod_out_remain / out_dim_i;
|
||||
size_t offset_i = offsets[i];
|
||||
|
||||
if (i == rank - 1) {
|
||||
PADDLE_ASSERT(x_stride == 1 && out_stride == 1);
|
||||
x += offset_i;
|
||||
for (size_t j = 0; j < out_dim_i; ++j) {
|
||||
*out++ = *x++;
|
||||
}
|
||||
} else {
|
||||
x += offset_i * x_stride;
|
||||
for (size_t j = 0; j < out_dim_i; ++j) {
|
||||
StridedMemcpy<T>(x, x_dims, out, out_dims, i + 1, rank, x_stride,
|
||||
out_stride, offsets);
|
||||
x += x_stride;
|
||||
out += out_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
struct RandomCropFunctor {
|
||||
const T* x_;
|
||||
T* out_;
|
||||
size_t x_dims_[9];
|
||||
size_t out_dims_[9];
|
||||
int num_batchsize_dims_;
|
||||
int rank_;
|
||||
int64_t seed_;
|
||||
|
||||
size_t prod_batchsize_dims_;
|
||||
size_t prod_x_ins_dims_;
|
||||
size_t prod_out_ins_dims_;
|
||||
|
||||
RandomCropFunctor(const T* x, T* out, const framework::DDim& x_dims,
|
||||
const framework::DDim& out_dims, int num_batchsize_dims,
|
||||
int64_t seed)
|
||||
: x_(x),
|
||||
out_(out),
|
||||
num_batchsize_dims_(num_batchsize_dims),
|
||||
rank_(x_dims.size()),
|
||||
seed_(seed) {
|
||||
PADDLE_ENFORCE_EQ(x_dims.size(), out_dims.size());
|
||||
PADDLE_ENFORCE_GT(rank_, num_batchsize_dims_);
|
||||
prod_batchsize_dims_ = 1;
|
||||
prod_x_ins_dims_ = 1;
|
||||
prod_out_ins_dims_ = 1;
|
||||
for (size_t i = 0; i < static_cast<size_t>(rank_); ++i) {
|
||||
size_t x_dim_i = x_dims[i];
|
||||
size_t out_dim_i = out_dims[i];
|
||||
x_dims_[i] = x_dim_i;
|
||||
out_dims_[i] = out_dim_i;
|
||||
if (i < static_cast<size_t>(num_batchsize_dims_)) {
|
||||
PADDLE_ENFORCE_EQ(x_dim_i, out_dim_i);
|
||||
prod_batchsize_dims_ *= x_dim_i;
|
||||
} else {
|
||||
prod_x_ins_dims_ *= x_dim_i;
|
||||
prod_out_ins_dims_ *= out_dim_i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HOSTDEVICE void operator()(size_t ins_idx) {
|
||||
typename Random<DeviceContext>::Engine engine(seed_);
|
||||
engine.discard(ins_idx * (rank_ - num_batchsize_dims_));
|
||||
size_t offsets[9];
|
||||
for (int i = num_batchsize_dims_; i < rank_; ++i) {
|
||||
typename Random<DeviceContext>::template UniformIntDist<size_t> dist(
|
||||
0, x_dims_[i] - out_dims_[i]);
|
||||
offsets[i - num_batchsize_dims_] = dist(engine);
|
||||
}
|
||||
|
||||
const T* x = x_ + ins_idx * prod_x_ins_dims_;
|
||||
T* out = out_ + ins_idx * prod_out_ins_dims_;
|
||||
|
||||
StridedMemcpy<T>(x, x_dims_ + num_batchsize_dims_, out,
|
||||
out_dims_ + num_batchsize_dims_, 0,
|
||||
rank_ - num_batchsize_dims_, prod_x_ins_dims_,
|
||||
prod_out_ins_dims_, offsets);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class RandomCropKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
virtual void Compute(const framework::ExecutionContext& ctx) const {
|
||||
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
|
||||
int64_t seed = 0;
|
||||
if (platform::is_cpu_place(seed_tensor.place())) {
|
||||
seed = *seed_tensor.data<int64_t>();
|
||||
} else {
|
||||
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
|
||||
"your program";
|
||||
framework::LoDTensor cpu_seed;
|
||||
framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
|
||||
seed = *cpu_seed.data<int64_t>();
|
||||
}
|
||||
auto shape = ctx.Attr<std::vector<int>>("shape");
|
||||
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
|
||||
auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
|
||||
|
||||
int num_batchsize_dims = x.dims().size() - shape.size();
|
||||
RandomCropFunctor<DeviceContext, T> functor(
|
||||
x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
|
||||
num_batchsize_dims, seed);
|
||||
platform::ForRange<DeviceContext> for_range(
|
||||
ctx.template device_context<DeviceContext>(),
|
||||
functor.prod_batchsize_dims_);
|
||||
|
||||
for_range(functor);
|
||||
|
||||
Random<platform::CPUDeviceContext>::Engine engine(seed);
|
||||
engine.discard(functor.prod_batchsize_dims_ *
|
||||
(functor.rank_ - functor.num_batchsize_dims_));
|
||||
*ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
|
||||
platform::CPUPlace()) = engine();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(fengjiayi): Backward of random crop op
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,46 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid.core as core
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestRandomCropOp(OpTest):
|
||||
def setUp(self):
|
||||
to_crop = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]] *
|
||||
5).astype("float32")
|
||||
self.possible_res = [
|
||||
np.array([[1, 2, 3], [5, 6, 7]]), np.array([[2, 3, 4], [6, 7, 8]]),
|
||||
np.array([[5, 6, 7], [9, 10, 11]]),
|
||||
np.array([[6, 7, 8], [10, 11, 12]])
|
||||
]
|
||||
self.op_type = "random_crop"
|
||||
self.inputs = {'X': to_crop, 'Seed': np.array([10])}
|
||||
self.outputs = {'Out': np.array([]), 'SeedOut': np.array([])}
|
||||
self.attrs = {'shape': [2, 3]}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_customized(self.verify_output)
|
||||
|
||||
def verify_output(self, outs):
|
||||
out = np.array(outs[1])
|
||||
for ins in out[:]:
|
||||
is_equal = [(ins == res).all() for res in self.possible_res]
|
||||
self.assertIn(True, is_equal)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Binary file not shown.
Loading…
Reference in new issue