|
|
|
@ -21,8 +21,8 @@ namespace reader {
|
|
|
|
|
|
|
|
|
|
class CustomReader : public framework::DecoratedReader {
|
|
|
|
|
public:
|
|
|
|
|
CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block,
|
|
|
|
|
const framework::Scope& scope, const platform::Place& dev_place,
|
|
|
|
|
CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block,
|
|
|
|
|
const framework::Scope* scope, const platform::Place& dev_place,
|
|
|
|
|
const std::vector<std::string>& source_var_names,
|
|
|
|
|
const std::vector<std::string>& sink_var_names)
|
|
|
|
|
: DecoratedReader(reader),
|
|
|
|
@ -34,9 +34,15 @@ class CustomReader : public framework::DecoratedReader {
|
|
|
|
|
|
|
|
|
|
void ReadNext(std::vector<framework::LoDTensor>* out) override;
|
|
|
|
|
|
|
|
|
|
void UpdateBlockAndScope(const framework::BlockDesc* sub_block,
|
|
|
|
|
const framework::Scope* scope) {
|
|
|
|
|
sub_block_ = sub_block;
|
|
|
|
|
scope_ = scope;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const framework::BlockDesc& sub_block_;
|
|
|
|
|
const framework::Scope& scope_;
|
|
|
|
|
const framework::BlockDesc* sub_block_;
|
|
|
|
|
const framework::Scope* scope_;
|
|
|
|
|
platform::Place dev_place_;
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> source_var_names_;
|
|
|
|
@ -52,14 +58,17 @@ class CreateCustomReaderOp : public framework::OperatorBase {
|
|
|
|
|
const platform::Place& dev_place) const override {
|
|
|
|
|
auto* out = scope.FindVar(Output("Out"))
|
|
|
|
|
->template GetMutable<framework::ReaderHolder>();
|
|
|
|
|
auto* sub_block = Attr<framework::BlockDesc*>("sub_block");
|
|
|
|
|
if (out->Get() != nullptr) {
|
|
|
|
|
auto* custom_reader = reinterpret_cast<CustomReader*>(out->Get());
|
|
|
|
|
custom_reader->UpdateBlockAndScope(sub_block, &scope);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
|
|
|
|
|
->Get<framework::ReaderHolder>();
|
|
|
|
|
out->Reset(new CustomReader(
|
|
|
|
|
underlying_reader.Get(), *Attr<framework::BlockDesc*>("sub_block"),
|
|
|
|
|
scope, dev_place, Attr<std::vector<std::string>>("source_var_names"),
|
|
|
|
|
out->Reset(
|
|
|
|
|
new CustomReader(underlying_reader.Get(), sub_block, &scope, dev_place,
|
|
|
|
|
Attr<std::vector<std::string>>("source_var_names"),
|
|
|
|
|
Attr<std::vector<std::string>>("sink_var_names")));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -141,31 +150,28 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
|
|
|
|
|
"the size of underlying_outs(%d) are not consistent. Each feeding "
|
|
|
|
|
"element must have its own source and sink variable.",
|
|
|
|
|
source_var_names_.size(), sink_var_names_.size(), underlying_outs.size());
|
|
|
|
|
|
|
|
|
|
framework::Scope* exe_scope = &scope_->NewScope();
|
|
|
|
|
// 1. Copy LoDTensors from underlying reader's output to source variables.
|
|
|
|
|
for (size_t i = 0; i < source_var_names_.size(); ++i) {
|
|
|
|
|
framework::Variable* var = scope_.FindVar(source_var_names_[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var, "CustomReader's source variable '%s' doesn't exist.");
|
|
|
|
|
framework::Variable* var = exe_scope->Var(source_var_names_[i]);
|
|
|
|
|
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->ShareDataWith(underlying_outs[i]);
|
|
|
|
|
tensor->set_lod(underlying_outs[i].lod());
|
|
|
|
|
}
|
|
|
|
|
// 2. Run the sub-block.
|
|
|
|
|
framework::Executor executor(dev_place_);
|
|
|
|
|
framework::ProgramDesc* program = sub_block_.Program();
|
|
|
|
|
framework::Scope* exe_scope = &scope_.NewScope();
|
|
|
|
|
executor.Run(*program, exe_scope, sub_block_.ID(), false, true);
|
|
|
|
|
scope_.DeleteScope(exe_scope);
|
|
|
|
|
framework::ProgramDesc* program = sub_block_->Program();
|
|
|
|
|
executor.Run(*program, exe_scope, sub_block_->ID(), false, true);
|
|
|
|
|
// 3. Copy LoDTensors from sink variables to out.
|
|
|
|
|
out->resize(sink_var_names_.size());
|
|
|
|
|
for (size_t i = 0; i < sink_var_names_.size(); ++i) {
|
|
|
|
|
framework::Variable* var = scope_.FindVar(sink_var_names_[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var,
|
|
|
|
|
"CustomReader's sink variable '%s' doesn't exist.");
|
|
|
|
|
framework::Variable* var = exe_scope->FindVar(sink_var_names_[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
const framework::LoDTensor& tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|
(*out)[i].ShareDataWith(tensor);
|
|
|
|
|
(*out)[i].set_lod(tensor.lod());
|
|
|
|
|
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
|
|
|
|
|
}
|
|
|
|
|
scope_->DeleteScope(exe_scope);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace reader
|
|
|
|
|