|
|
|
@ -161,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
VarBasePtrMap& outputs,
|
|
|
|
|
VarBasePtrMap* outputs,
|
|
|
|
|
framework::AttributeMap attrs_map,
|
|
|
|
|
const platform::Place expected_place,
|
|
|
|
|
const bool stop_gradient) {
|
|
|
|
@ -195,7 +195,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->output_vars_ = outputs;
|
|
|
|
|
op->output_vars_ = *outputs;
|
|
|
|
|
for (auto it : op->output_vars_) {
|
|
|
|
|
auto& outvars = outvars_map[it.first];
|
|
|
|
|
const std::vector<VarBase*>& outputs = it.second;
|
|
|
|
@ -218,7 +218,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
framework::VariableNameMap invars_name_map =
|
|
|
|
|
CreateInputVarNameMap(op, inputs);
|
|
|
|
|
framework::VariableNameMap outvars_name_map =
|
|
|
|
|
CreateOutputVarNameMap(op, outputs);
|
|
|
|
|
CreateOutputVarNameMap(op, *outputs);
|
|
|
|
|
|
|
|
|
|
auto& info = framework::OpInfoMap::Instance().Get(op->Type());
|
|
|
|
|
if (info.Checker() != nullptr) {
|
|
|
|
@ -230,8 +230,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
outvars_name_map, attrs_map);
|
|
|
|
|
|
|
|
|
|
if (info.infer_var_type_) {
|
|
|
|
|
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, &outputs,
|
|
|
|
|
&attrs_map);
|
|
|
|
|
RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
|
|
|
|
|
info.infer_var_type_(&infer_var_type_ctx);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|