|
|
|
@ -35,14 +35,17 @@ class SumOpVarTypeInference : public VarTypeInference {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const OpDescBind &op_desc,
|
|
|
|
|
BlockDescBind *block) const override {
|
|
|
|
|
auto default_var_type = VarDesc::LOD_TENSOR;
|
|
|
|
|
for (auto &in_var_name : op_desc.Input("X")) {
|
|
|
|
|
auto in_var_type = block->Var(in_var_name)->GetType();
|
|
|
|
|
if (in_var_type != default_var_type) {
|
|
|
|
|
default_var_type = in_var_type;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto &inputs = op_desc.Input("X");
|
|
|
|
|
auto default_var_type = VarDesc::SELECTED_ROWS;
|
|
|
|
|
|
|
|
|
|
bool any_input_is_lod_tensor = std::any_of(
|
|
|
|
|
inputs.begin(), inputs.end(), [block](const std::string &name) {
|
|
|
|
|
return block->Var(name)->GetType() == VarDesc::LOD_TENSOR;
|
|
|
|
|
});
|
|
|
|
|
if (any_input_is_lod_tensor) {
|
|
|
|
|
default_var_type = VarDesc::LOD_TENSOR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto out_var_name = op_desc.Output("Out").front();
|
|
|
|
|
block->Var(out_var_name)->SetType(default_var_type);
|
|
|
|
|
}
|
|
|
|
@ -65,20 +68,18 @@ TEST(InferVarType, sum_op) {
|
|
|
|
|
op->SetInput("X", {"test_a", "test_b", "test_c"});
|
|
|
|
|
op->SetOutput("Out", {"test_out"});
|
|
|
|
|
|
|
|
|
|
prog.Block(0)->NewVar("test_a")->SetType(VarDesc_VarType_LOD_TENSOR);
|
|
|
|
|
prog.Block(0)->NewVar("test_b")->SetType(VarDesc_VarType_LOD_TENSOR);
|
|
|
|
|
prog.Block(0)->NewVar("test_c")->SetType(VarDesc_VarType_LOD_TENSOR);
|
|
|
|
|
prog.Block(0)->NewVar("test_a")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test_b")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test_c")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test_out");
|
|
|
|
|
|
|
|
|
|
op->InferVarType(prog.Block(0));
|
|
|
|
|
|
|
|
|
|
ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
|
|
|
|
|
prog.Block(0)->Var("test_out")->GetType());
|
|
|
|
|
ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType());
|
|
|
|
|
|
|
|
|
|
prog.Block(0)->Var("test_b")->SetType(VarDesc_VarType_SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
|
|
|
|
|
op->InferVarType(prog.Block(0));
|
|
|
|
|
ASSERT_EQ(VarDesc_VarType_SELECTED_ROWS,
|
|
|
|
|
prog.Block(0)->Var("test_out")->GetType());
|
|
|
|
|
ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(InferVarType, sum_op_without_infer_var_type) {
|
|
|
|
@ -88,9 +89,9 @@ TEST(InferVarType, sum_op_without_infer_var_type) {
|
|
|
|
|
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
|
|
|
|
|
op->SetOutput("Out", {"test2_out"});
|
|
|
|
|
|
|
|
|
|
prog.Block(0)->NewVar("test2_a")->SetType(VarDesc_VarType_LOD_TENSOR);
|
|
|
|
|
prog.Block(0)->NewVar("test2_b")->SetType(VarDesc_VarType_SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test2_c")->SetType(VarDesc_VarType_LOD_TENSOR);
|
|
|
|
|
prog.Block(0)->NewVar("test2_a")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test2_b")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test2_c")->SetType(VarDesc::SELECTED_ROWS);
|
|
|
|
|
prog.Block(0)->NewVar("test2_out");
|
|
|
|
|
|
|
|
|
|
op->InferVarType(prog.Block(0));
|
|
|
|
|