[OP] Add randperm op (#23292)
parent 08e3d9c0dc
commit 9297f49e4b
@@ -0,0 +1,96 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/randperm_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

class RandpermOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "The output(Out) of randperm op must not be null."));
    int n = ctx->Attrs().Get<int>("n");
    PADDLE_ENFORCE_GT(
        n, 0, platform::errors::InvalidArgument(
                  "The input(n) of randperm op must be greater than 0."));

    ctx->SetOutputDim("Out", framework::make_ddim({n}));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type =
        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
    return framework::OpKernelType(data_type, ctx.GetPlace());
  }
};

class RandpermOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddOutput("Out", "The output tensor of randperm op.");

    AddAttr<int>(
        "n", "The upper bound (exclusive), and it should be greater than 0.");
    AddAttr<int>("dtype",
                 "The data type of output tensor. "
                 "Default: 3[int64].")
        .SetDefault(framework::proto::VarType::INT64);
    AddAttr<int>("seed",
                 "Random seed used for permute samples. "
                 "0 means use a seed generated by the system. "
                 "Note that if seed is not 0, this operator will always "
                 "generate the same random permutation every time. "
                 "Default: 0.")
        .SetDefault(0);

    AddComment(R"DOC(
This operator returns a random permutation of integers from 0 to n-1.
)DOC");
  }
};

class RandpermOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto var_data_type = static_cast<framework::proto::VarType::Type>(
        boost::get<int>(ctx->GetAttr("dtype")));
    auto out_var_name = ctx->Output("Out").front();
    ctx->SetDataType(out_var_name, var_data_type);
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OPERATOR(
    randperm, paddle::operators::RandpermOp, paddle::operators::RandpermOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    paddle::operators::RandpermOpVarTypeInference);

template <typename T>
using kernel =
    paddle::operators::RandpermKernel<paddle::platform::CPUDeviceContext, T>;

REGISTER_OP_CPU_KERNEL(randperm, kernel<int64_t>, kernel<int>);
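The registration above exposes the op as `randperm` with attributes `n`, `dtype`, and `seed`. A minimal sketch of how the wired-up Python API is exercised, based on the dygraph usage in the tests later in this diff (assuming a Paddle build of this vintage, where `paddle.randperm` accepts `n` and `dtype`):

import paddle
import paddle.fluid as fluid

# A minimal sketch, not part of the patch: run the randperm op eagerly
# and print one permutation of 0..4, e.g. [3, 0, 4, 1, 2].
with fluid.dygraph.guard():
    perm = paddle.randperm(5, dtype="int64")
    print(perm.numpy())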
@@ -0,0 +1,23 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/randperm_op.h"

template <typename T>
using kernel =
    paddle::operators::RandpermKernel<paddle::platform::CUDADeviceContext, T>;

REGISTER_OP_CUDA_KERNEL(randperm, kernel<int64_t>, kernel<int>);
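Note that the CUDA registration reuses the shared RandpermKernel from the header below, which generates the permutation on the host and copies it to the device. A quick way to exercise this path, sketched under the assumption of a CUDA-enabled build:

import numpy as np
import paddle
import paddle.fluid as fluid

# A sketch, not part of the patch; assumes a CUDA build.
if fluid.core.is_compiled_with_cuda():
    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        out = paddle.randperm(100, dtype="int32")
        # Sorting the output should recover exactly 0..99.
        assert (np.sort(out.numpy()) == np.arange(100)).all()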
@@ -0,0 +1,65 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <random>  // for std::random_device
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

// Fills data_ptr with 0..num-1, then shuffles it in place. A seed of 0
// requests a non-deterministic seed from std::random_device.
template <typename T>
static inline void random_permate(T* data_ptr, int num, unsigned int seed) {
  for (int i = 0; i < num; ++i) {
    data_ptr[i] = static_cast<T>(i);
  }
  if (seed == 0) {
    seed = std::random_device()();
  }
  std::srand(seed);
  std::random_shuffle(data_ptr, data_ptr + num);
}

template <typename DeviceContext, typename T>
class RandpermKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int n = ctx.Attr<int>("n");
    unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
    framework::Variable* out_var = ctx.OutputVar("Out");
    framework::Tensor* out_tensor =
        framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);

    if (platform::is_cpu_place(ctx.GetPlace())) {
      T* out_data = out_tensor->mutable_data<T>(platform::CPUPlace());
      random_permate<T>(out_data, n, seed);
    } else {
      // On GPU the permutation is still generated on the host, then
      // copied to the device.
      framework::Tensor tmp_tensor;
      tmp_tensor.Resize(framework::make_ddim({n}));
      T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
      random_permate<T>(tmp_data, n, seed);
      framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor);
    }
  }
};

}  // namespace operators
}  // namespace paddle
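For reference, the host-side generation in random_permate amounts to the following NumPy sketch (the name randperm_reference is mine, not part of the patch):

import numpy as np

def randperm_reference(n, seed=0):
    # Hypothetical helper mirroring random_permate: fill 0..n-1, then
    # shuffle in place; seed 0 means "pick a fresh system seed".
    rng = np.random.RandomState(seed if seed != 0 else None)
    data = np.arange(n)
    rng.shuffle(data)
    return data

print(randperm_reference(5, seed=10))  # deterministic for a fixed nonzero seed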
@@ -0,0 +1,175 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid import Program, program_guard


def check_randperm_out(n, data_np):
    assert isinstance(data_np, np.ndarray), \
        "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # Reduce to a single bool; returning a list here would make
    # assertTrue vacuous, since any non-empty list is truthy.
    return bool((gt_sorted == out_sorted).all())


def error_msg(data_np):
    return "The sorted ground truth and sorted out should " + \
        "be equal, out = " + str(data_np)


def convert_dtype(dtype_str):
    dtype_str_list = ["int32", "int64"]
    dtype_num_list = [2, 3]
    assert dtype_str in dtype_str_list, dtype_str + \
        " should be in " + str(dtype_str_list)
    return dtype_num_list[dtype_str_list.index(dtype_str)]


class TestRandpermOp(OpTest):
    """Test randperm op."""

    def setUp(self):
        self.op_type = "randperm"
        self.n = 200
        self.dtype = "int64"
        self.device = None
        self.seed = 0

        self.inputs = {}
        self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
        self.init_attrs()
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
            "device": self.device,
            "seed": self.seed,
        }

    def init_attrs(self):
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        out_np = np.array(outs[0])
        self.assertTrue(
            check_randperm_out(self.n, out_np), msg=error_msg(out_np))


class TestRandpermOp_attr_n(TestRandpermOp):
    """Test randperm op for attr n."""

    def init_attrs(self):
        self.n = 10000


class TestRandpermOp_attr_int32(TestRandpermOp):
    """Test randperm op for attr int32 dtype."""

    def init_attrs(self):
        self.dtype = "int32"


class TestRandpermOp_attr_device_cpu(TestRandpermOp):
    """Test randperm op for cpu device."""

    def init_attrs(self):
        self.device = "cpu"


class TestRandpermOp_attr_device_gpu(TestRandpermOp):
    """Test randperm op for gpu device."""

    def init_attrs(self):
        self.device = "gpu"


class TestRandpermOp_attr_seed(TestRandpermOp):
    """Test randperm op for attr seed."""

    def init_attrs(self):
        self.seed = 10


class TestRandpermOpError(unittest.TestCase):
    """Test that randperm op raises the expected errors."""

    def test_errors(self):
        main_prog = Program()
        start_prog = Program()
        with program_guard(main_prog, start_prog):

            def test_Variable():
                out = np.arange(10)
                paddle.randperm(n=10, out=out)

            self.assertRaises(TypeError, test_Variable)

            def test_value():
                paddle.randperm(n=-3)

            self.assertRaises(ValueError, test_value)


class TestRandpermOp_attr_out(unittest.TestCase):
    """Test randperm op for attr out."""

    def test_attr_tensor_API(self):
        startup_program = fluid.Program()
        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            n = 10
            data_1 = fluid.layers.fill_constant([n], "int64", 3)
            paddle.randperm(n=n, out=data_1)

            data_2 = paddle.randperm(n=n, dtype="int32", device="cpu")

            place = fluid.CPUPlace()
            if fluid.core.is_compiled_with_cuda():
                place = fluid.CUDAPlace(0)
            exe = fluid.Executor(place)

            exe.run(startup_program)
            outs = exe.run(train_program, fetch_list=[data_1, data_2])

            out_np = np.array(outs[0])
            self.assertTrue(
                check_randperm_out(n, out_np), msg=error_msg(out_np))


class TestRandpermDygraphMode(unittest.TestCase):
    def test_check_output(self):
        with fluid.dygraph.guard():
            n = 10
            data_1 = paddle.randperm(n, dtype="int64")
            data_1_np = data_1.numpy()
            self.assertTrue(
                check_randperm_out(n, data_1_np), msg=error_msg(data_1_np))

            data_2 = paddle.randperm(n, dtype="int32", device="cpu")
            data_2_np = data_2.numpy()
            self.assertTrue(
                check_randperm_out(n, data_2_np), msg=error_msg(data_2_np))


if __name__ == "__main__":
    unittest.main()