|
|
@ -46,9 +46,11 @@ class ElementwiseOpInferVarType : public framework::VarTypeInference {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
void operator()(const framework::OpDesc& op_desc,
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
framework::BlockDesc* block) const override {
|
|
|
|
auto x_var = op_desc.Input("X")[0];
|
|
|
|
auto x_name = op_desc.Input("X")[0];
|
|
|
|
auto out_var = op_desc.Output("Out")[0];
|
|
|
|
auto out_name = op_desc.Output("Out")[0];
|
|
|
|
block->Var(out_var)->SetType(block->Var(x_var)->GetType());
|
|
|
|
auto& x = block->FindRecursiveOrCreateVar(x_name);
|
|
|
|
|
|
|
|
auto& out = block->FindRecursiveOrCreateVar(out_name);
|
|
|
|
|
|
|
|
out.SetType(x.GetType());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|