!6946 MSLITE fix compare function output dataType

Merge pull request !6946 from 徐安越/master
pull/6946/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit afef3ec6f6

@ -35,7 +35,7 @@ int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeUInt8);
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}

@ -29,5 +29,15 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
return RET_OK;
}
#endif
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -35,6 +35,7 @@ class Greater : public Arithmetic {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

@ -28,5 +28,15 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
return RET_OK;
}
#endif
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -36,6 +36,7 @@ class GreaterEqual : public Arithmetic {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

@ -30,5 +30,15 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
return RET_OK;
}
#endif
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -36,6 +36,7 @@ class Less : public Arithmetic {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

@ -29,5 +29,15 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
return RET_OK;
}
#endif
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -36,6 +36,7 @@ class LessEqual : public Arithmetic {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

@ -29,5 +29,15 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer
return RET_OK;
}
#endif
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->SetFormat(input->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -36,6 +36,7 @@ class NotEqual : public Arithmetic {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save