|
|
|
@ -49,6 +49,10 @@ FileReaderMakerBase::FileReaderMakerBase(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
!ctx->IsRuntime(),
|
|
|
|
|
"'FileReaderInferShape' should only be invoked during compile time.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"The output file reader should not be null.");
|
|
|
|
|
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
|
|
|
|
@ -56,16 +60,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
|
|
|
|
|
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
|
|
|
|
|
ctx->SetReaderDims("Out", shapes);
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
|
|
|
|
|
"The number of 'lod_levels'(%d) doesn't match the number "
|
|
|
|
|
"of 'shapes'(%d).",
|
|
|
|
|
lod_levels.size(), shapes.size());
|
|
|
|
|
framework::VarDesc* reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
reader->SetLoDLevels(lod_levels);
|
|
|
|
|
}
|
|
|
|
|
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
|
|
|
|
|
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
|
|
|
|
|
"The number of 'lod_levels'(%d) doesn't match the number "
|
|
|
|
|
"of 'shapes'(%d).",
|
|
|
|
|
lod_levels.size(), shapes.size());
|
|
|
|
|
framework::VarDesc* reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
reader->SetLoDLevels(lod_levels);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
|
|
|
|
@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
|
|
|
|
|
|
|
|
|
|
void DecoratedReaderInferShape::operator()(
|
|
|
|
|
framework::InferShapeContext* ctx) const {
|
|
|
|
|
PADDLE_ENFORCE(!ctx->IsRuntime(),
|
|
|
|
|
"'DecoratedReaderInferShape' should only be invoked during "
|
|
|
|
|
"compile time.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
|
|
|
|
|
"Input(UnderlyingReader) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"The output decorated reader should not be null.");
|
|
|
|
|
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
|
|
|
|
|
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
|
|
|
|
|
framework::VarDesc* out_reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
|
|
|
|
|
}
|
|
|
|
|
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
|
|
|
|
|
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
|
|
|
|
|
framework::VarDesc* out_reader =
|
|
|
|
|
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
|
|
|
|
|
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
|
|
|
|
|
}
|
|
|
|
|
void DecoratedReaderInferVarType::operator()(
|
|
|
|
|
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {
|
|
|
|
|