fix a potential bug in the c++ reader

shanyi15-patch-2
fengjiayi 7 years ago
parent ccc5418841
commit 614c33fb3a

@ -65,12 +65,25 @@ class ReaderHolder {
ReaderBase* Get() const { return reader_.get(); }
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
void ReInit() { reader_->ReInit(); }
void ReadNext(std::vector<LoDTensor>* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
void ReInit() {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReInit();
}
DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}

@ -49,6 +49,10 @@ FileReaderMakerBase::FileReaderMakerBase(
}
void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(
!ctx->IsRuntime(),
"'FileReaderInferShape' should only be invoked during compile time.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output file reader should not be null.");
const auto shape_concat = ctx->Attrs().Get<std::vector<int>>("shape_concat");
@ -56,16 +60,14 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
ctx->SetReaderDims("Out", shapes);
if (ctx->IsRuntime()) {
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
const auto lod_levels = ctx->Attrs().Get<std::vector<int>>("lod_levels");
PADDLE_ENFORCE_EQ(lod_levels.size(), shapes.size(),
"The number of 'lod_levels'(%d) doesn't match the number "
"of 'shapes'(%d).",
lod_levels.size(), shapes.size());
framework::VarDesc* reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
reader->SetLoDLevels(lod_levels);
}
void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
@ -77,19 +79,21 @@ void FileReaderInferVarType::operator()(const framework::OpDesc& op_desc,
void DecoratedReaderInferShape::operator()(
framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"'DecoratedReaderInferShape' should only be invoked during "
"compile time.");
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"),
"Input(UnderlyingReader) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"The output decorated reader should not be null.");
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader"));
if (ctx->IsRuntime()) {
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
framework::VarDesc* in_reader = boost::get<framework::VarDesc*>(
ctx->GetInputVarPtrs("UnderlyingReader")[0]);
framework::VarDesc* out_reader =
boost::get<framework::VarDesc*>(ctx->GetOutputVarPtrs("Out")[0]);
out_reader->SetLoDLevels(in_reader->GetLoDLevels());
}
void DecoratedReaderInferVarType::operator()(
const framework::OpDesc& op_desc, framework::BlockDesc* block) const {

Loading…
Cancel
Save