|
|
|
@ -66,33 +66,12 @@ platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// framework::BlockDesc* InferShapeAndVarType(OpBase* op, const VarBasePtrMap&
|
|
|
|
|
// inputs, const VarBasePtrMap& outputs) {
|
|
|
|
|
// std::unique_ptr<BlockDesc> block(new BlockDesc());
|
|
|
|
|
|
|
|
|
|
// // construct op desc
|
|
|
|
|
// op->op_desc_ = block.AppendOp();
|
|
|
|
|
|
|
|
|
|
// // construct op inputs and outputs
|
|
|
|
|
// // for
|
|
|
|
|
// //
|
|
|
|
|
// for (auto it = )
|
|
|
|
|
// op->op_desc_->SetInput()
|
|
|
|
|
|
|
|
|
|
// op->op_desc_->InferShape(*block);
|
|
|
|
|
// op->op_desc_->InferVarType(block.get());
|
|
|
|
|
|
|
|
|
|
// return block.release();
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
const VarBasePtrMap& outputs, framework::BlockDesc* block,
|
|
|
|
|
const platform::Place expected_place,
|
|
|
|
|
const bool stop_gradient) {
|
|
|
|
|
std::map<std::string, VarBase*> vars;
|
|
|
|
|
|
|
|
|
|
// framework::BlockDesc* block = InferShapeAndVarType(op, inputs, outputs);
|
|
|
|
|
|
|
|
|
|
framework::OpDesc* op_desc = op->op_desc_;
|
|
|
|
|
VLOG(3) << "tracer tracing " << op_desc->Type();
|
|
|
|
|
op_desc->InferShape(*block);
|
|
|
|
|