|
|
|
@ -42,6 +42,18 @@ class CreateFileReaderInferShape : public framework::InferShapeBase {
|
|
|
|
|
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks");
|
|
|
|
|
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
|
|
|
|
ctx->SetReaderDims("Out", shapes);
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
lod_levels.size(), shapes.size(),
|
|
|
|
|
"The number of 'lod_levels'(%d) doesn't match the number "
|
|
|
|
|
"of 'shapes'(%d).",
|
|
|
|
|
lod_levels.size(), shapes.size());
|
|
|
|
|
framework::VarDesc* reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
reader->SetLoDLevels(lod_levels);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -54,11 +66,19 @@ class CreateDecoratedReaderInferShape : public framework::InferShapeBase {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"The output decorated reader should not be null.");
|
|
|
|
|
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
|
|
|
|
|
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
|
|
|
|
|
framework::VarDesc* out_reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// general var type inference for all readers
|
|
|
|
|
class CreateReaderInferVarType : public framework::VarTypeInference {
|
|
|
|
|
// general var type inference for file readers
|
|
|
|
|
class CreateFileReaderInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
@ -68,6 +88,20 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// general var type inference for decorated readers
|
|
|
|
|
class CreateDecoratedReaderInferVarType : public framework::VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
|
std::string in_reader_name = op_desc.Input("UnderlyingReader")[0];
|
|
|
|
|
framework::VarDesc* in_reader = block->FindVarRecursive(in_reader_name);
|
|
|
|
|
std::string out_reader_name = op_desc.Output("Out")[0];
|
|
|
|
|
framework::VarDesc* out_reader = block->FindVarRecursive(out_reader_name);
|
|
|
|
|
out_reader->SetType(framework::proto::VarDesc::READER);
|
|
|
|
|
out_reader->SetDataTypes(in_reader->GetDataTypes());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
|
|
|
|
|
public:
|
|
|
|
@ -105,6 +139,7 @@ class CreateRandomDataGeneratorOpMaker
|
|
|
|
|
"ranks = [3,2]"
|
|
|
|
|
"It means the reader will generate two data each time,"
|
|
|
|
|
"whose shapes are [2,3,4] and [5,6] respectively.");
|
|
|
|
|
AddAttr<std::vector<int>>("lod_levels", "The LoD levels of each data.");
|
|
|
|
|
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
|
|
|
|
|
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
@ -192,14 +227,14 @@ REGISTER_OPERATOR(create_random_data_generator,
|
|
|
|
|
ops::CreateFileReaderInferShape,
|
|
|
|
|
ops::CreateRandomDataGeneratorOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::CreateReaderInferVarType);
|
|
|
|
|
ops::CreateFileReaderInferVarType);
|
|
|
|
|
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
|
|
|
|
|
ops::CreateDecoratedReaderInferShape,
|
|
|
|
|
ops::CreateShuffleReaderOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::CreateReaderInferVarType);
|
|
|
|
|
ops::CreateDecoratedReaderInferVarType);
|
|
|
|
|
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
|
|
|
|
|
ops::CreateDecoratedReaderInferShape,
|
|
|
|
|
ops::CreateBatchReaderOpMaker,
|
|
|
|
|
paddle::framework::EmptyGradOpMaker,
|
|
|
|
|
ops::CreateReaderInferVarType);
|
|
|
|
|
ops::CreateDecoratedReaderInferVarType);
|
|
|
|
|