|
|
|
@ -18,8 +18,8 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
|
|
|
|
|
const std::vector<int>& ranks) {
|
|
|
|
|
static std::vector<framework::DDim> RestoreShapes(
|
|
|
|
|
const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
|
|
|
|
|
std::vector<framework::DDim> res;
|
|
|
|
|
int offset = 0;
|
|
|
|
|
for (int len : ranks) {
|
|
|
|
@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CreateRandomReaderOp : public framework::OperatorBase {
|
|
|
|
|
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorBase::OperatorBase;
|
|
|
|
|
void Run(const framework::Scope& scope,
|
|
|
|
@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase {
|
|
|
|
|
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
|
|
|
|
|
Attr<float>("max")));
|
|
|
|
|
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
|
|
|
|
|
Attr<float>("max")));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class CreateRandomDataGeneratorOpMaker
|
|
|
|
|
: public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
|
|
|
|
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(op_proto, op_checker) {
|
|
|
|
|
AddOutput("Out", "(ReaderHolder) The created random reader.");
|
|
|
|
|
AddAttr<std::vector<int>>("shape_concat",
|
|
|
|
@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
|
|
|
|
|
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
CreateRandomReader Operator
|
|
|
|
|
CreateRandomDataGenerator Operator
|
|
|
|
|
|
|
|
|
|
This Op creates a random reader.
|
|
|
|
|
The reader generates random data instead of really reading from files.
|
|
|
|
@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
|
|
|
|
|
REGISTER_OPERATOR(create_random_data_generator,
|
|
|
|
|
ops::CreateRandomDataGeneratorOp<float>,
|
|
|
|
|
ops::CreateFileReaderInferShape,
|
|
|
|
|
ops::CreateRandomReaderOpMaker,
|
|
|
|
|
ops::CreateRandomDataGeneratorOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::CreateReaderInferVarType);
|
|
|
|
|
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
|
|
|
|
|