|
|
|
@ -35,40 +35,49 @@ constexpr size_t kOutputNum = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
int OneHotCPUKernel::Init() {
|
|
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
|
|
|
|
set_need_reinit();
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
// indices depth on_value off_value
|
|
|
|
|
if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size()
|
|
|
|
|
<< ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (context_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot context nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
thread_num_ = context_->thread_num_;
|
|
|
|
|
|
|
|
|
|
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
axis_ = param->axis_;
|
|
|
|
|
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return ReSize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int OneHotCPUKernel::ReSize() {
|
|
|
|
|
auto indices = in_tensors_.at(0);
|
|
|
|
|
if (indices == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
auto indices_shape = indices->shape();
|
|
|
|
|
const int indices_rank = static_cast<int>(indices_shape.size());
|
|
|
|
|
if (axis_ < 0) {
|
|
|
|
|
axis_ += indices_rank + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
outer_size_ = 1;
|
|
|
|
|
for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
|
|
|
|
|
outer_size_ *= indices_shape[i];
|
|
|
|
|
}
|
|
|
|
|
inner_size_ = indices->ElementsNum() / outer_size_;
|
|
|
|
|
|
|
|
|
|
if (context_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "OneHot context nullptr";
|
|
|
|
|
return RET_NULL_PTR;
|
|
|
|
|
}
|
|
|
|
|
thread_num_ = context_->thread_num_;
|
|
|
|
|
|
|
|
|
|
const int indices_rank = static_cast<int>(in_tensors_.at(0)->shape().size());
|
|
|
|
|
if (axis_ < 0) {
|
|
|
|
|
axis_ += indices_rank + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|