|
|
|
@ -31,14 +31,15 @@ using mindspore::schema::PrimitiveType_OneHot;
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kInputNum = 4;
|
|
|
|
|
constexpr size_t kInputNumOpt = 3;
|
|
|
|
|
constexpr size_t kOutputNum = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
int OneHotCPUKernel::Init() {
|
|
|
|
|
// indices depth on_value off_value
|
|
|
|
|
if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size()
|
|
|
|
|
<< ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
|
|
|
|
|
if ((in_tensors_.size() != kInputNum && in_tensors_.size() != kInputNumOpt) || out_tensors_.size() != kOutputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << " or " << kInputNumOpt << ", got "
|
|
|
|
|
<< in_tensors_.size() << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (context_ == nullptr) {
|
|
|
|
@ -132,27 +133,42 @@ int OneHotCPUKernel::GetParams() {
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->depth_ = *depth;
|
|
|
|
|
|
|
|
|
|
auto on_value_tensor = in_tensors_.at(2);
|
|
|
|
|
if (on_value_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const float *on_value = static_cast<float *>(on_value_tensor->MutableData());
|
|
|
|
|
if (on_value == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->on_value_ = *on_value;
|
|
|
|
|
|
|
|
|
|
auto off_value_tensor = in_tensors_.at(3);
|
|
|
|
|
if (off_value_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const float *off_value = static_cast<float *>(off_value_tensor->MutableData());
|
|
|
|
|
if (off_value == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
if (in_tensors_.size() == kInputNum) {
|
|
|
|
|
auto on_value_tensor = in_tensors_.at(2);
|
|
|
|
|
if (on_value_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const float *on_value = static_cast<float *>(on_value_tensor->MutableData());
|
|
|
|
|
if (on_value == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->on_value_ = *on_value;
|
|
|
|
|
|
|
|
|
|
auto off_value_tensor = in_tensors_.at(3);
|
|
|
|
|
if (off_value_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const float *off_value = static_cast<float *>(off_value_tensor->MutableData());
|
|
|
|
|
if (off_value == nullptr) {
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->off_value_ = *off_value;
|
|
|
|
|
} else {
|
|
|
|
|
auto off_on_tensor = in_tensors_.at(2);
|
|
|
|
|
if (off_on_tensor == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
const int64_t *off_on_values = static_cast<int64_t *>(off_on_tensor->MutableData());
|
|
|
|
|
if (off_on_values == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot input[2] data is nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->off_value_ = static_cast<float>(off_on_values[0]);
|
|
|
|
|
one_hot_param->on_value_ = static_cast<float>(off_on_values[1]);
|
|
|
|
|
}
|
|
|
|
|
one_hot_param->off_value_ = *off_value;
|
|
|
|
|
|
|
|
|
|
one_hot_param->outer_size_ = outer_size_;
|
|
|
|
|
one_hot_param->inner_size_ = inner_size_;
|
|
|
|
|