[OP] Add randperm op (#23292)
parent
08e3d9c0dc
commit
9297f49e4b
@ -0,0 +1,96 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/randperm_op.h"
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class RandpermOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
||||
platform::errors::NotFound(
|
||||
"The output(Out) of randperm op must not be null."));
|
||||
int n = ctx->Attrs().Get<int>("n");
|
||||
PADDLE_ENFORCE_GT(
|
||||
n, 0, platform::errors::InvalidArgument(
|
||||
"The input(n) of randperm op must be greater than 0."));
|
||||
|
||||
ctx->SetOutputDim("Out", framework::make_ddim({n}));
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
auto data_type =
|
||||
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
|
||||
return framework::OpKernelType(data_type, ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
class RandpermOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddOutput("Out", "The output tensor of randperm op.");
|
||||
|
||||
AddAttr<int>(
|
||||
"n", "The upper bound (exclusive), and it should be greater than 0.");
|
||||
AddAttr<int>("dtype",
|
||||
"The data type of output tensor. "
|
||||
"Default: 3[int64].")
|
||||
.SetDefault(framework::proto::VarType::INT64);
|
||||
AddAttr<int>("seed",
|
||||
"Random seed used for permute samples. "
|
||||
"0 means use a seed generated by the system."
|
||||
"Note that if seed is not 0, this operator will always "
|
||||
"generate the same random permutation every time. "
|
||||
"Default: 0.")
|
||||
.SetDefault(0);
|
||||
|
||||
AddComment(R"DOC(
|
||||
This operator returns a random permutation of integers from 0 to n-1.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class RandpermOpVarTypeInference : public framework::VarTypeInference {
|
||||
public:
|
||||
void operator()(framework::InferVarTypeContext *ctx) const override {
|
||||
auto var_data_type = static_cast<framework::proto::VarType::Type>(
|
||||
boost::get<int>(ctx->GetAttr("dtype")));
|
||||
auto out_var_name = ctx->Output("Out").front();
|
||||
ctx->SetDataType(out_var_name, var_data_type);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_OPERATOR(
|
||||
randperm, paddle::operators::RandpermOp, paddle::operators::RandpermOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
|
||||
paddle::operators::RandpermOpVarTypeInference);
|
||||
|
||||
template <typename T>
|
||||
using kernel =
|
||||
paddle::operators::RandpermKernel<paddle::platform::CPUDeviceContext, T>;
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(randperm, kernel<int64_t>, kernel<int>);
|
||||
@ -0,0 +1,23 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/operators/randperm_op.h"
|
||||
|
||||
template <typename T>
|
||||
using kernel =
|
||||
paddle::operators::RandpermKernel<paddle::platform::CUDADeviceContext, T>;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(randperm, kernel<int64_t>, kernel<int>);
|
||||
@ -0,0 +1,65 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
static inline void random_permate(T* data_ptr, int num, unsigned int seed) {
|
||||
for (int i = 0; i < num; ++i) {
|
||||
data_ptr[i] = static_cast<T>(i);
|
||||
}
|
||||
if (seed == 0) {
|
||||
seed = std::random_device()();
|
||||
}
|
||||
std::srand(seed);
|
||||
std::random_shuffle(data_ptr, data_ptr + num);
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class RandpermKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
int n = ctx.Attr<int>("n");
|
||||
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
|
||||
framework::Variable* out_var = ctx.OutputVar("Out");
|
||||
framework::Tensor* out_tensor =
|
||||
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
|
||||
|
||||
if (platform::is_cpu_place(ctx.GetPlace())) {
|
||||
T* out_data = out_tensor->mutable_data<T>(platform::CPUPlace());
|
||||
random_permate<T>(out_data, n, seed);
|
||||
} else {
|
||||
framework::Tensor tmp_tensor;
|
||||
tmp_tensor.Resize(framework::make_ddim({n}));
|
||||
T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
|
||||
random_permate<T>(tmp_data, n, seed);
|
||||
framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
@ -0,0 +1,175 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.op import Operator
|
||||
from paddle.fluid import Program, program_guard
|
||||
|
||||
|
||||
def check_randperm_out(n, data_np):
|
||||
assert isinstance(data_np, np.ndarray), \
|
||||
"The input data_np should be np.ndarray."
|
||||
gt_sorted = np.arange(n)
|
||||
out_sorted = np.sort(data_np)
|
||||
return list(gt_sorted == out_sorted)
|
||||
|
||||
|
||||
def error_msg(data_np):
|
||||
return "The sorted ground truth and sorted out should " + \
|
||||
"be equal, out = " + str(data_np)
|
||||
|
||||
|
||||
def convert_dtype(dtype_str):
|
||||
dtype_str_list = ["int32", "int64"]
|
||||
dtype_num_list = [2, 3]
|
||||
assert dtype_str in dtype_str_list, dtype_str + \
|
||||
" should in " + str(dtype_str_list)
|
||||
return dtype_num_list[dtype_str_list.index(dtype_str)]
|
||||
|
||||
|
||||
class TestRandpermOp(OpTest):
|
||||
""" Test randperm op."""
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "randperm"
|
||||
self.n = 200
|
||||
self.dtype = "int64"
|
||||
self.device = None
|
||||
self.seed = 0
|
||||
|
||||
self.inputs = {}
|
||||
self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
|
||||
self.init_attrs()
|
||||
self.attrs = {
|
||||
"n": self.n,
|
||||
"dtype": convert_dtype(self.dtype),
|
||||
"device": self.device,
|
||||
"seed": self.seed,
|
||||
}
|
||||
|
||||
def init_attrs(self):
|
||||
pass
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output_customized(self.verify_output)
|
||||
|
||||
def verify_output(self, outs):
|
||||
out_np = np.array(outs[0])
|
||||
self.assertTrue(
|
||||
check_randperm_out(self.n, out_np), msg=error_msg(out_np))
|
||||
|
||||
|
||||
class TestRandpermOp_attr_n(TestRandpermOp):
|
||||
""" Test randperm op for attr n. """
|
||||
|
||||
def init_attrs(self):
|
||||
self.n = 10000
|
||||
|
||||
|
||||
class TestRandpermOp_attr_int32(TestRandpermOp):
|
||||
""" Test randperm op for attr int32 dtype. """
|
||||
|
||||
def init_attrs(self):
|
||||
self.dtype = "int32"
|
||||
|
||||
|
||||
class TestRandpermOp_attr_device_cpu(TestRandpermOp):
|
||||
""" Test randperm op for cpu device. """
|
||||
|
||||
def init_attrs(self):
|
||||
self.device = "cpu"
|
||||
|
||||
|
||||
class TestRandpermOp_attr_device_gpu(TestRandpermOp):
|
||||
""" Test randperm op for gpu device. """
|
||||
|
||||
def init_attrs(self):
|
||||
self.device = "gpu"
|
||||
|
||||
|
||||
class TestRandpermOp_attr_seed(TestRandpermOp):
|
||||
""" Test randperm op for attr seed. """
|
||||
|
||||
def init_attrs(self):
|
||||
self.seed = 10
|
||||
|
||||
|
||||
class TestRandpermOpError(unittest.TestCase):
|
||||
""" Test randperm op for raise error. """
|
||||
|
||||
def test_errors(self):
|
||||
main_prog = Program()
|
||||
start_prog = Program()
|
||||
with program_guard(main_prog, start_prog):
|
||||
|
||||
def test_Variable():
|
||||
out = np.arange(10)
|
||||
paddle.randperm(n=10, out=out)
|
||||
|
||||
self.assertRaises(TypeError, test_Variable)
|
||||
|
||||
def test_value():
|
||||
paddle.randperm(n=-3)
|
||||
|
||||
self.assertRaises(ValueError, test_value)
|
||||
|
||||
|
||||
class TestRandpermOp_attr_out(unittest.TestCase):
|
||||
""" Test randperm op for attr out. """
|
||||
|
||||
def test_attr_tensor_API(self):
|
||||
startup_program = fluid.Program()
|
||||
train_program = fluid.Program()
|
||||
with fluid.program_guard(train_program, startup_program):
|
||||
n = 10
|
||||
data_1 = fluid.layers.fill_constant([n], "int64", 3)
|
||||
paddle.randperm(n=n, out=data_1)
|
||||
|
||||
data_2 = paddle.randperm(n=n, dtype="int32", device="cpu")
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
if fluid.core.is_compiled_with_cuda():
|
||||
place = fluid.CUDAPlace(0)
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
exe.run(startup_program)
|
||||
outs = exe.run(train_program, fetch_list=[data_1, data_2])
|
||||
|
||||
out_np = np.array(outs[0])
|
||||
self.assertTrue(
|
||||
check_randperm_out(n, out_np), msg=error_msg(out_np))
|
||||
|
||||
|
||||
class TestRandpermDygraphMode(unittest.TestCase):
|
||||
def test_check_output(self):
|
||||
with fluid.dygraph.guard():
|
||||
n = 10
|
||||
data_1 = paddle.randperm(n, dtype="int64")
|
||||
data_1_np = data_1.numpy()
|
||||
self.assertTrue(
|
||||
check_randperm_out(n, data_1_np), msg=error_msg(data_1_np))
|
||||
|
||||
data_2 = paddle.randperm(n, dtype="int32", device="cpu")
|
||||
data_2_np = data_2.numpy()
|
||||
self.assertTrue(
|
||||
check_randperm_out(n, data_2_np), msg=error_msg(data_2_np))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue