|
|
@ -45,36 +45,36 @@ class OneHotGpuFwdKernel : public GpuKernel {
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
bool Init(const CNodePtr &kernel_node) override {
|
|
|
|
bool Init(const CNodePtr &kernel_node) override {
|
|
|
|
int axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
|
|
|
|
int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
|
|
|
|
auto input = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
auto output = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
int input_size = SizeToInt(input.size());
|
|
|
|
int64_t input_dims = static_cast<int64_t>(input_shape.size());
|
|
|
|
const int default_axis = -1;
|
|
|
|
if (axis >= input_dims) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input_shape.size();
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
const int64_t default_axis = -1;
|
|
|
|
|
|
|
|
|
|
|
|
// Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims).
|
|
|
|
// Compress arbitrary tensor dimensions into three dimensions (left_dims, depth, right_dims).
|
|
|
|
for (int i = 0; i < input_size; i++) {
|
|
|
|
for (size_t i = 0; i < input_shape.size(); i++) {
|
|
|
|
auto dim_size = input[IntToSize(i)];
|
|
|
|
auto dim_size = input_shape[i];
|
|
|
|
if (axis == default_axis || i < axis) {
|
|
|
|
if (axis == default_axis || i < IntToSize(axis)) {
|
|
|
|
left_dim_size_ *= dim_size;
|
|
|
|
left_dim_size_ *= dim_size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (axis != default_axis && i >= axis) {
|
|
|
|
if (axis != default_axis && i >= IntToSize(axis)) {
|
|
|
|
right_dim_size_ *= dim_size;
|
|
|
|
right_dim_size_ *= dim_size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto size : input) {
|
|
|
|
for (auto size : input_shape) {
|
|
|
|
input_size_ *= size;
|
|
|
|
input_size_ *= size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto size : output) {
|
|
|
|
for (auto size : output_shape) {
|
|
|
|
output_size_ *= size;
|
|
|
|
output_size_ *= size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (axis >= input_size) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "invalid one hot axis value: " << axis << " for input dims size: " << input.size();
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (axis == default_axis) {
|
|
|
|
if (axis == default_axis) {
|
|
|
|
depth_ = output[output.size() - 1];
|
|
|
|
depth_ = output_shape[output_shape.size() - 1];
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
depth_ = output[IntToSize(axis)];
|
|
|
|
depth_ = output_shape[IntToSize(axis)];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
InitSizeLists();
|
|
|
|
InitSizeLists();
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|