You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
6.7 KiB
171 lines
6.7 KiB
5 years ago
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||
|
//
|
||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
// you may not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing, software
|
||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
// See the License for the specific language governing permissions and
|
||
|
// limitations under the License.
|
||
|
|
||
|
#include <string>
|
||
|
#include <vector>
|
||
|
#include "paddle/fluid/framework/op_registry.h"
|
||
|
#include "paddle/fluid/framework/operator.h"
|
||
|
#include "paddle/fluid/operators/uniform_random_op.h"
|
||
|
#include "paddle/fluid/platform/enforce.h"
|
||
|
|
||
|
namespace paddle {
|
||
|
namespace operators {
|
||
|
|
||
|
template <typename T>
|
||
|
class CPURandintKernel : public framework::OpKernel<T> {
|
||
|
public:
|
||
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
||
|
std::vector<int64_t> new_shape;
|
||
|
auto list_new_shape_tensor =
|
||
|
ctx.MultiInput<framework::Tensor>("ShapeTensorList");
|
||
|
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
|
||
|
if (ctx.HasInput("ShapeTensor")) {
|
||
|
auto* shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
|
||
|
new_shape = GetNewDataFromShapeTensor(shape_tensor);
|
||
|
} else if (list_new_shape_tensor.size() > 0) {
|
||
|
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
auto* out = ctx.Output<framework::LoDTensor>("Out");
|
||
|
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
|
||
|
T* data = out->mutable_data<T>(ctx.GetPlace());
|
||
|
int64_t size = out->numel();
|
||
|
std::random_device rd;
|
||
|
std::mt19937 gen(rd());
|
||
|
std::uniform_int_distribution<> dist(ctx.Attr<int>("low"),
|
||
|
ctx.Attr<int>("high") - 1);
|
||
|
for (int64_t i = 0; i < size; ++i) data[i] = dist(gen);
|
||
|
}
|
||
|
};
|
||
|
|
||
|
class RandintOp : public framework::OperatorWithKernel {
|
||
|
public:
|
||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||
|
|
||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||
|
PADDLE_ENFORCE_EQ(
|
||
|
ctx->HasOutput("Out"), true,
|
||
|
platform::errors::InvalidArgument("Output(Out) of RandintOp is null."));
|
||
|
PADDLE_ENFORCE_LT(
|
||
|
ctx->Attrs().Get<int>("low"), ctx->Attrs().Get<int>("high"),
|
||
|
platform::errors::InvalidArgument("randint's low must less then high, "
|
||
|
"but received: low = %d, high = %d.",
|
||
|
ctx->Attrs().Get<int>("low"),
|
||
|
ctx->Attrs().Get<int>("high")));
|
||
|
|
||
|
if (ctx->HasInputs("ShapeTensorList")) {
|
||
|
// top prority shape
|
||
|
auto inputs_name = ctx->Inputs("ShapeTensorList");
|
||
|
PADDLE_ENFORCE_GT(
|
||
|
inputs_name.size(), 0,
|
||
|
platform::errors::InvalidArgument(
|
||
|
"Input(ShapeTensorList)'size of Op(randint) can't be zero."
|
||
|
"Please check the Attr(shape)'s size of"
|
||
|
"Op(fluid.layers.randint).)"));
|
||
|
auto out_dims = std::vector<int>(inputs_name.size(), -1);
|
||
|
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
|
||
|
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
|
||
|
if (ctx->HasInput("ShapeTensor") && shape.empty()) {
|
||
|
auto shape_dims = ctx->GetInputDim("ShapeTensor");
|
||
|
PADDLE_ENFORCE_EQ(shape_dims.size(), 1,
|
||
|
platform::errors::InvalidArgument(
|
||
|
"ShapeError: Input(ShapeTensor)' dimension size of "
|
||
|
"Op(randint) must be 1."
|
||
|
"But received ShapeTensor's dimensions = %d.",
|
||
|
shape_dims.size()));
|
||
|
int num_ele = 1;
|
||
|
for (int i = 0; i < shape_dims.size(); ++i) {
|
||
|
num_ele *= shape_dims[i];
|
||
|
}
|
||
|
auto vec_dims = std::vector<int64_t>(num_ele, -1);
|
||
|
auto out_dims = framework::make_ddim(vec_dims);
|
||
|
ctx->SetOutputDim("Out", out_dims);
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
PADDLE_ENFORCE_EQ(shape.empty(), false,
|
||
|
platform::errors::InvalidArgument(
|
||
|
"if there is no Input(ShapeTensorList) and no "
|
||
|
"Input(ShapeTensor),the "
|
||
|
"attr(shape) information must "
|
||
|
"be set by Attr(shape)."));
|
||
|
|
||
|
std::vector<int64_t> tensor_shape;
|
||
|
tensor_shape.reserve(shape.size());
|
||
|
for (auto dim : shape) {
|
||
|
tensor_shape.push_back(static_cast<int64_t>(dim));
|
||
|
}
|
||
|
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
|
||
|
}
|
||
|
|
||
|
protected:
|
||
|
framework::OpKernelType GetExpectedKernelType(
|
||
|
const framework::ExecutionContext& ctx) const override {
|
||
|
return framework::OpKernelType(
|
||
|
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
|
||
|
ctx.GetPlace());
|
||
|
}
|
||
|
};
|
||
|
|
||
|
// Declares the inputs, output, attributes and documentation of the randint
// operator.
class RandintOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers:
  //   - Input(ShapeTensor): optional.
  //   - Input(ShapeTensorList): optional, duplicable.
  //   - Output(Out).
  //   - Attr(shape), defaulting to an empty vector.
  //   - Attr(low) and Attr(high), with no default.
  //   - Attr(dtype), defaulting to INT64.
  void Make() override {
    AddInput("ShapeTensor",
             "(Tensor<int64_t> or Tensor<int32_t>, optional). If provided, "
             "randint generates the output "
             "according to "
             "this given shape. It means that it has a higher priority than "
             "Attr(shape) but a lower priority than Input(ShapeTensorList).")
        .AsDispensable();
    AddInput("ShapeTensorList",
             "(vector<Tensor<int64_t>> or vector<Tensor<int32_t>>, optional). "
             "If provided, randint use this. The shape of the tensor "
             "must be [1], it has the highest priority comparing with "
             "Input(ShapeTensor) and attr(shape).")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out", "The output tensor of randint op");
    AddComment(R"DOC(
This operator initializes a tensor with random integers sampled from a
uniform distribution. The random result is in set [low, high).
)DOC");
    AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor.")
        .SetDefault({});
    AddAttr<int>("low",
                 "The lower bound on the range of random values to generate.");
    AddAttr<int>("high",
                 "The upper bound on the range of random values to generate.");
    AddAttr<int>("dtype", "Output tensor data type. [Default INT64].")
        .SetDefault(framework::proto::VarType::INT64);
  }
};
|
||
|
|
||
|
} // namespace operators
|
||
|
} // namespace paddle
|
||
|
|
||
|
namespace ops = paddle::operators;
|
||
|
|
||
|
REGISTER_OPERATOR(
|
||
|
randint, ops::RandintOp, ops::RandintOpMaker,
|
||
|
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
||
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
|
||
|
|
||
|
REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel<int>,
|
||
|
ops::CPURandintKernel<int64_t>)
|