|
|
@ -18,7 +18,7 @@
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
|
// general infershape
|
|
|
|
// general infershape for file readers
|
|
|
|
class CreateReaderInferShape : public framework::InferShapeBase {
|
|
|
|
class CreateReaderInferShape : public framework::InferShapeBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(framework::InferShapeContext* ctx) const override {
|
|
|
|
void operator()(framework::InferShapeContext* ctx) const override {
|
|
|
@ -35,6 +35,7 @@ class CreateRandomReaderOp : public framework::OperatorBase {
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
|
|
|
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
|
|
|
|
const auto& ranks = Attr<std::vector<int>>("ranks");
|
|
|
|
const auto& ranks = Attr<std::vector<int>>("ranks");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
|
|
|
|
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
|
|
|
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
|
|
|
|
int(shape_concat.size()),
|
|
|
|
int(shape_concat.size()),
|
|
|
|
"The accumulate of all ranks should be equal to the "
|
|
|
|
"The accumulate of all ranks should be equal to the "
|
|
|
@ -49,8 +50,9 @@ class CreateRandomReaderOp : public framework::OperatorBase {
|
|
|
|
offset += len;
|
|
|
|
offset += len;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
->template GetMutable<framework::RandomReader<T>>();
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
out->Initialize(shapes, Attr<float>("min"), Attr<float>("max"));
|
|
|
|
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
|
|
|
|
|
|
|
|
Attr<float>("max")));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -58,7 +60,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
|
|
|
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(op_proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(op_proto, op_checker) {
|
|
|
|
AddOutput("Out", "(RandomReader) The created random reader.");
|
|
|
|
AddOutput("Out", "(ReaderHolder) The created random reader.");
|
|
|
|
AddAttr<std::vector<int>>("shape_concat",
|
|
|
|
AddAttr<std::vector<int>>("shape_concat",
|
|
|
|
"The concat of all data's shapes.");
|
|
|
|
"The concat of all data's shapes.");
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
|
AddAttr<std::vector<int>>(
|
|
|
@ -81,10 +83,57 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CreateShuffleReaderInferShape : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
void operator()(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Underlying_reader"),
|
|
|
|
|
|
|
|
"Input(Underlying_reader) of CreateShuffleReaderOp should "
|
|
|
|
|
|
|
|
"not be null.");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
|
|
|
"Output(Out) of CreateShuffleReaderOp should not be null.");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CreateShuffleReaderOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
using framework::OperatorBase::OperatorBase;
|
|
|
|
|
|
|
|
void Run(const framework::Scope& scope,
|
|
|
|
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
|
|
|
|
const auto& underlying_reader = scope.FindVar(Input("Underlying_reader"))
|
|
|
|
|
|
|
|
->Get<framework::ReaderHolder>();
|
|
|
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
|
|
|
out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
|
|
|
|
|
|
|
|
Attr<int>("buffer_size")));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
|
|
|
|
|
|
|
|
: OpProtoAndCheckerMaker(op_proto, op_checker) {
|
|
|
|
|
|
|
|
AddInput(
|
|
|
|
|
|
|
|
"Underlying_reader",
|
|
|
|
|
|
|
|
"(ReaderHolder) The underlying reader for creating a shuffle reader.");
|
|
|
|
|
|
|
|
AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
|
|
|
|
|
|
|
|
AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
|
|
|
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
|
|
|
CreateShuffleReader Operator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
A shuffle reader takes another reader as its 'underlying reader'
|
|
|
|
|
|
|
|
and output the underlying reader's outputs in a shuffled order.
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace operators
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
|
|
|
|
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
|
|
|
|
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
|
|
|
|
ops::CreateReaderInferShape, ops::CreateRandomReaderOpMaker,
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
|
|
|
|
|
|
|
|
ops::CreateShuffleReaderInferShape,
|
|
|
|
|
|
|
|
ops::CreateShuffleReaderOpMaker,
|
|
|
|
|
|
|
|
paddle::framework::EmptyGradOpMaker);
|
|
|
|