|
|
|
@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator);
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kOneHotInputNum = 4;
|
|
|
|
|
}
|
|
|
|
|
constexpr size_t kOneHotInputNumOpt = 3;
|
|
|
|
|
} // namespace
|
|
|
|
|
int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
|
|
|
|
|
if (this->primitive_ == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
@ -90,8 +91,10 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu
|
|
|
|
|
|
|
|
|
|
int axis = GetAxis();
|
|
|
|
|
// indices, depth, on_value, off_value
|
|
|
|
|
if (inputs.size() != kOneHotInputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
|
|
|
|
|
// indices, depth, on_off_value(contain 2 values);
|
|
|
|
|
if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or "
|
|
|
|
|
<< kOneHotInputNumOpt;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto depth_tensor = inputs.at(1);
|
|
|
|
|