|
|
|
@ -62,12 +62,14 @@ void FileReaderMakerBase::Make() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
!ctx->IsRuntime(),
|
|
|
|
|
"'FileReaderInferShape' should only be invoked during compile time.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"The output file reader should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
ctx->IsRuntime(), true,
|
|
|
|
|
platform::errors::PreconditionNotMet("'FileReaderInferShape' should only "
|
|
|
|
|
"be invoked during compile time."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::NotFound("The output file reader should not be null."));
|
|
|
|
|
bool use_data_config = ctx->Attrs().Get<bool>("use_data_config");
|
|
|
|
|
if (use_data_config) {
|
|
|
|
|
const auto shape_concat =
|
|
|
|
@ -77,21 +79,26 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
|
|
|
|
|
ctx->SetReaderDims("Out", shapes);
|
|
|
|
|
|
|
|
|
|
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
|
|
|
|
|
"The number of 'lod_levels'(%d) doesn't match the number "
|
|
|
|
|
"of 'shapes'(%d).",
|
|
|
|
|
lod_levels.size(), shapes.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
lod_levels.size(), shapes.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of 'lod_levels'(%d) doesn't match the number "
|
|
|
|
|
"of 'shapes'(%d).",
|
|
|
|
|
lod_levels.size(), shapes.size()));
|
|
|
|
|
const auto dtypes = ctx->Attrs().Get<std::vector<int>>("dtypes");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dtypes.size(), shapes.size(),
|
|
|
|
|
"The number of 'dtypes'(%d) doesn't match the number of 'shapes'(%d).",
|
|
|
|
|
dtypes.size(), shapes.size());
|
|
|
|
|
platform::errors::InvalidArgument("The number of 'dtypes'(%d) doesn't "
|
|
|
|
|
"match the number of 'shapes'(%d).",
|
|
|
|
|
dtypes.size(), shapes.size()));
|
|
|
|
|
const auto need_check_feed =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("need_check_feed");
|
|
|
|
|
PADDLE_ENFORCE_EQ(need_check_feed.size(), shapes.size(),
|
|
|
|
|
"The number of 'need_check_feed'(%d) doesn't match the "
|
|
|
|
|
"number of 'shapes'(%d).",
|
|
|
|
|
need_check_feed.size(), shapes.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
need_check_feed.size(), shapes.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of 'need_check_feed'(%d) doesn't match the "
|
|
|
|
|
"number of 'shapes'(%d).",
|
|
|
|
|
need_check_feed.size(), shapes.size()));
|
|
|
|
|
framework::VarDesc* reader =
|
|
|
|
|
BOOST_GET(framework::VarDesc*, ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
reader->SetLoDLevels(lod_levels);
|
|
|
|
@ -105,14 +112,18 @@ void FileReaderInferVarType::operator()(
|
|
|
|
|
|
|
|
|
|
void DecoratedReaderInferShape::operator()(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(!ctx->IsRuntime(),
|
|
|
|
|
"'DecoratedReaderInferShape' should only be invoked during "
|
|
|
|
|
"compile time.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
|
|
|
|
|
"Input(UnderlyingReader) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"The output decorated reader should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
ctx->IsRuntime(), true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"'DecoratedReaderInferShape' should only be invoked during "
|
|
|
|
|
"compile time."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("UnderlyingReader"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Input(UnderlyingReader) should not be null."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"The output decorated reader should not be null."));
|
|
|
|
|
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
|
|
|
|
|
|
|
|
|
|
framework::VarDesc* in_reader = BOOST_GET(
|
|
|
|
|