|
|
|
@ -63,11 +63,11 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
|
|
|
|
|
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*outputs*/) {
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
if (dtype_ == kNumberTypeFloat16) {
|
|
|
|
|
LaunchKernel<float16>(inputs);
|
|
|
|
|
LaunchKernel<float16>(inputs, outputs);
|
|
|
|
|
} else if (dtype_ == kNumberTypeFloat32) {
|
|
|
|
|
LaunchKernel<float>(inputs);
|
|
|
|
|
LaunchKernel<float>(inputs, outputs);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Only support float16, float32";
|
|
|
|
|
return false;
|
|
|
|
@ -76,7 +76,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs) {
|
|
|
|
|
void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
auto x = reinterpret_cast<T *>(inputs[0]->addr);
|
|
|
|
|
auto indices = reinterpret_cast<int *>(inputs[1]->addr);
|
|
|
|
|
auto updates = reinterpret_cast<T *>(inputs[2]->addr);
|
|
|
|
@ -100,6 +101,10 @@ void ScatterNdUpdateCPUKernel::LaunchKernel(const std::vector<AddressPtr> &input
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto ret = memcpy_s(outputs[0]->addr, outputs[0]->size, x, mem_size);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ScatterNdUpdateCPUKernel::Check(const CNodePtr &kernel_node) {
|
|
|
|
|