|
|
|
@ -23,14 +23,14 @@ namespace operators {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
framework::Tensor* tensor = nullptr;
|
|
|
|
|
void Compute(const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
framework::Tensor *tensor = nullptr;
|
|
|
|
|
auto out_var = ctx.OutputVar("Out");
|
|
|
|
|
if (out_var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
tensor = out_var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
} else if (out_var->IsType<framework::SelectedRows>()) {
|
|
|
|
|
auto shape = ctx.Attr<std::vector<int>>("shape");
|
|
|
|
|
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
|
|
|
|
|
tensor = selected_rows->mutable_value();
|
|
|
|
|
tensor->Resize(framework::make_ddim(shape));
|
|
|
|
|
selected_rows->mutable_rows()->reserve(shape[0]);
|
|
|
|
@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
|
|
|
|
|
"uniform_random_op's output only"
|
|
|
|
|
"supports SelectedRows and LoDTensor");
|
|
|
|
|
}
|
|
|
|
|
T* data = tensor->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
T *data = tensor->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
|
|
|
|
|
std::minstd_rand engine;
|
|
|
|
|
if (seed == 0) {
|
|
|
|
@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of UniformRandomOp should not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
|
|
|
|
|
"uniform_random's min must less then max");
|
|
|
|
|
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
|
|
|
|
std::vector<int64_t> temp;
|
|
|
|
|
temp.reserve(shape.size());
|
|
|
|
|
for (auto dim : shape) {
|
|
|
|
@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
|
|
|
|
|
|
|
|
|
|
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
|
void operator()(const framework::OpDesc &op_desc,
|
|
|
|
|
framework::BlockDesc *block) const override {
|
|
|
|
|
auto out_var_name = op_desc.Output("Out").front();
|
|
|
|
|
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
|
|
|
|
|
framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
block->FindRecursiveOrCreateVar(out_var_name)
|
|
|
|
|
.SetType(framework::proto::VarType::SELECTED_ROWS);
|
|
|
|
|
} else {
|
|
|
|
|
block->FindRecursiveOrCreateVar(out_var_name)
|
|
|
|
|
.SetType(framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
auto var_data_type = static_cast<framework::proto::VarType::Type>(
|
|
|
|
|
boost::get<int>(op_desc.GetAttr("dtype")));
|
|
|
|
|
|
|
|
|
|
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
|
|
|
|
|
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
|
|
|
|
|
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
|
|
|
|
|
}
|
|
|
|
|
out_var.SetDataType(var_data_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|