!9492 Fix implement bug in pack and unpack

From: @yuan_shen_zhou
Reviewed-by: @liangchenghui,@linqingke
Signed-off-by: @liangchenghui
pull/9492/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9d1e6ca59f

@ -18,23 +18,38 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PackGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PackGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Pack, MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
PackGpuFwdKernel, int) PackGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Pack, MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
PackGpuFwdKernel, int16_t) PackGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PackGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
PackGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Pack, MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
PackGpuFwdKernel, uchar) PackGpuFwdKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(Pack, MS_REG_GPU_KERNEL_ONE(Pack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
PackGpuFwdKernel, bool) PackGpuFwdKernel, bool)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
PackGpuFwdKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
PackGpuFwdKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
PackGpuFwdKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PackGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
Pack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PackGpuFwdKernel, float)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

@ -46,7 +46,7 @@ class PackGpuFwdKernel : public GpuKernel {
inputs_host_.get(), sizeof(T *) * input_num_, cudaMemcpyHostToDevice, inputs_host_.get(), sizeof(T *) * input_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)), reinterpret_cast<cudaStream_t>(stream_ptr)),
"Pack opt cudaMemcpyAsync inputs failed"); "Pack opt cudaMemcpyAsync inputs failed");
PackKernel(SizeToInt(output_size_), input_num_, dims_behind_axis_, inputs_array, output, PackKernel(output_size_, input_num_, dims_behind_axis_, inputs_array, output,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }
@ -58,19 +58,22 @@ class PackGpuFwdKernel : public GpuKernel {
axis_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "axis")); axis_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < 0) { if (axis_ < 0) {
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
axis_ += SizeToInt(input_shape.size()); axis_ += (SizeToInt(input_shape.size()) + 1);
} }
auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node); auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0); auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
axis_ = AxisTransform(origin_data_format, input_format, axis_); axis_ = AxisTransform(origin_data_format, input_format, axis_);
input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node)); input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
inputs_host_ = std::make_unique<T *[]>(input_num_); inputs_host_ = std::make_unique<T *[]>(input_num_);
for (int i = 0; i < input_num_; i++) { for (size_t i = 0; i < input_num_; i++) {
size_t input_size = 1; size_t input_size = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) { for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= input_shape[j]; input_size *= input_shape[j];
if (i == 0 && j >= IntToSize(axis_)) {
dims_behind_axis_ *= input_shape[j];
}
} }
input_size_list_.push_back(input_size * sizeof(T)); input_size_list_.push_back(input_size * sizeof(T));
} }
@ -78,11 +81,8 @@ class PackGpuFwdKernel : public GpuKernel {
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
output_size_ = 1; output_size_ = 1;
for (int i = 0; i < SizeToInt(output_shape.size()); i++) { for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i]; output_size_ *= output_shape[i];
if (i > axis_ + 1) {
dims_behind_axis_ *= output_shape[i];
}
} }
output_size_list_.push_back(output_size_ * sizeof(T)); output_size_list_.push_back(output_size_ * sizeof(T));
InitSizeLists(); InitSizeLists();
@ -102,9 +102,9 @@ class PackGpuFwdKernel : public GpuKernel {
return true; return true;
} }
int axis_; int axis_;
int input_num_; size_t input_num_;
size_t output_size_; size_t output_size_;
int dims_behind_axis_; size_t dims_behind_axis_;
std::unique_ptr<T *[]> inputs_host_; std::unique_ptr<T *[]> inputs_host_;
std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;

@ -18,23 +18,38 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnpackGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnpackGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(Unpack, MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
UnpackGpuFwdKernel, int) UnpackGpuFwdKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Unpack, MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
UnpackGpuFwdKernel, int16_t) UnpackGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
UnpackGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
UnpackGpuFwdKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Unpack, MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
UnpackGpuFwdKernel, uchar) UnpackGpuFwdKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(Unpack, MS_REG_GPU_KERNEL_ONE(Unpack,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
UnpackGpuFwdKernel, bool) UnpackGpuFwdKernel, bool)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
UnpackGpuFwdKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
UnpackGpuFwdKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
UnpackGpuFwdKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnpackGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
Unpack, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnpackGpuFwdKernel, float)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

@ -46,7 +46,7 @@ class UnpackGpuFwdKernel : public GpuKernel {
outputs_host_.get(), sizeof(T *) * output_num_, cudaMemcpyHostToDevice, outputs_host_.get(), sizeof(T *) * output_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)), reinterpret_cast<cudaStream_t>(stream_ptr)),
"Unpack opt cudaMemcpyAsync outputs failed"); "Unpack opt cudaMemcpyAsync outputs failed");
UnpackKernel(SizeToInt(input_size_), output_num_, dims_after_axis_, outputs_array, input, UnpackKernel(input_size_, output_num_, dims_after_axis_, outputs_array, input,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
return true; return true;
} }
@ -64,9 +64,9 @@ class UnpackGpuFwdKernel : public GpuKernel {
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0); auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
axis_ = AxisTransform(origin_data_format, input_format, axis_); axis_ = AxisTransform(origin_data_format, input_format, axis_);
output_num_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "num")); output_num_ = LongToSize(GetAttr<int64_t>(kernel_node, "num"));
outputs_host_ = std::make_unique<T *[]>(output_num_); outputs_host_ = std::make_unique<T *[]>(output_num_);
for (int i = 0; i < output_num_; i++) { for (size_t i = 0; i < output_num_; i++) {
size_t _size = 1; size_t _size = 1;
auto _shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i); auto _shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
for (size_t j = 0; j < _shape.size(); j++) { for (size_t j = 0; j < _shape.size(); j++) {
@ -77,9 +77,9 @@ class UnpackGpuFwdKernel : public GpuKernel {
workspace_size_list_.push_back(sizeof(T *) * output_num_); workspace_size_list_.push_back(sizeof(T *) * output_num_);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
for (int i = 0; i < SizeToInt(input_shape.size()); i++) { for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i]; input_size_ *= input_shape[i];
if (i > axis_) { if (i > IntToSize(axis_)) {
dims_after_axis_ *= input_shape[i]; dims_after_axis_ *= input_shape[i];
} }
} }
@ -101,9 +101,9 @@ class UnpackGpuFwdKernel : public GpuKernel {
return true; return true;
} }
int axis_; int axis_;
int output_num_; size_t output_num_;
size_t input_size_; size_t input_size_;
int dims_after_axis_; size_t dims_after_axis_;
std::unique_ptr<T *[]> outputs_host_; std::unique_ptr<T *[]> outputs_host_;
std::vector<size_t> input_size_list_; std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_; std::vector<size_t> output_size_list_;

@ -19,39 +19,55 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/pack.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/pack.cuh"
template <typename T> template <typename T>
__global__ void Pack(const int size, const int input_num, const int dims_behind_axis, T** inputs, T* output) { __global__ void Pack(const size_t size, const size_t input_num, const size_t dims_behind_axis, T** inputs, T* output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int cycle = pos / (input_num * dims_behind_axis); size_t cur_input_index = pos / dims_behind_axis % input_num;
int cur_input_index = pos % (input_num * dims_behind_axis) / dims_behind_axis; size_t cycle_len = input_num * dims_behind_axis;
int local_index = pos % (input_num * dims_behind_axis) % dims_behind_axis; size_t local_index = pos / cycle_len * dims_behind_axis + pos % cycle_len % dims_behind_axis;
output[pos] = inputs[cur_input_index][cycle * dims_behind_axis + local_index]; output[pos] = inputs[cur_input_index][local_index];
} }
return; return;
} }
template <typename T> template <typename T>
void PackKernel(const int size, const int input_num, void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, T** inputs, T* output, const size_t dims_behind_axis, T** inputs, T* output,
cudaStream_t cuda_stream) { cudaStream_t cuda_stream) {
Pack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, dims_behind_axis, inputs, output); Pack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, dims_behind_axis, inputs, output);
return; return;
} }
template void PackKernel(const int size, const int input_num,
const int dims_behind_axis, float** inputs, float* output, template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, int8_t** inputs, int8_t* output,
cudaStream_t cuda_stream);
template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, int16_t** inputs, int16_t* output,
cudaStream_t cuda_stream);
template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, int** inputs, int* output,
cudaStream_t cuda_stream);
template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, int64_t** inputs, int64_t* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void PackKernel(const int size, const int input_num, template void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, int** inputs, int* output, const size_t dims_behind_axis, uint8_t** inputs, uint8_t* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void PackKernel(const int size, const int input_num, template void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, half** inputs, half* output, const size_t dims_behind_axis, uint16_t** inputs, uint16_t* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void PackKernel(const int size, const int input_num, template void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, short** inputs, short* output, // NOLINT const size_t dims_behind_axis, uint32_t** inputs, uint32_t* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void PackKernel(const int size, const int input_num, template void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, unsigned char** inputs, unsigned char* output, const size_t dims_behind_axis, uint64_t** inputs, uint64_t* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void PackKernel(const int size, const int input_num, template void PackKernel(const size_t size, const size_t input_num,
const int dims_behind_axis, bool** inputs, bool* output, const size_t dims_behind_axis, half** inputs, half* output,
cudaStream_t cuda_stream);
template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, float** inputs, float* output,
cudaStream_t cuda_stream);
template void PackKernel(const size_t size, const size_t input_num,
const size_t dims_behind_axis, bool** inputs, bool* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

@ -19,9 +19,9 @@
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T> template <typename T>
void PackKernel(const int size, void PackKernel(const size_t size,
const int input_num, const size_t input_num,
const int dims_behind_axis, const size_t dims_behind_axis,
T** inputs, T** inputs,
T* output, T* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

@ -19,41 +19,56 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/unpack.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/unpack.cuh"
template <typename T> template <typename T>
__global__ void Unpack(const int size, const int output_num, __global__ void Unpack(const size_t size, const size_t output_num,
const int dims_after_axis, T** outputs, const T* input) { const size_t dims_after_axis, T** outputs, const T* input) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int cycle = pos / (output_num * dims_after_axis); size_t cur_input_index = pos / dims_after_axis % output_num;
int cur_output_index = pos % (output_num * dims_after_axis) / dims_after_axis; size_t cycle_len = output_num * dims_after_axis;
int local_index = pos % (output_num * dims_after_axis) % dims_after_axis; size_t local_index = pos / cycle_len * dims_after_axis + pos % cycle_len % dims_after_axis;
outputs[cur_output_index][cycle * dims_after_axis + local_index] = input[pos]; outputs[cur_input_index][local_index] = input[pos];
} }
return; return;
} }
template <typename T> template <typename T>
void UnpackKernel(const int size, const int output_num, void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, T** outputs, const T* input, const size_t dims_after_axis, T** outputs, const T* input,
cudaStream_t cuda_stream) { cudaStream_t cuda_stream) {
Unpack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, output_num, Unpack<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, output_num,
dims_after_axis, outputs, input); dims_after_axis, outputs, input);
return; return;
} }
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, float** outputs, const float* input, const size_t dims_after_axis, int8_t** outputs, const int8_t* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, half** outputs, const half* input, const size_t dims_after_axis, int16_t** outputs, const int16_t* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, int** outputs, const int* input, const size_t dims_after_axis, int** outputs, const int* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, int16_t** outputs, const int16_t* input, const size_t dims_after_axis, int64_t** outputs, const int64_t* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, unsigned char** outputs, const unsigned char* input, const size_t dims_after_axis, uint8_t** outputs, const uint8_t* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void UnpackKernel(const int size, const int output_num, template void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, bool** outputs, const bool* input, const size_t dims_after_axis, uint16_t** outputs, const uint16_t* input,
cudaStream_t cuda_stream);
template void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, uint32_t** outputs, const uint32_t* input,
cudaStream_t cuda_stream);
template void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, uint64_t** outputs, const uint64_t* input,
cudaStream_t cuda_stream);
template void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, half** outputs, const half* input,
cudaStream_t cuda_stream);
template void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, float** outputs, const float* input,
cudaStream_t cuda_stream);
template void UnpackKernel(const size_t size, const size_t output_num,
const size_t dims_after_axis, bool** outputs, const bool* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

@ -19,7 +19,7 @@
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T> template <typename T>
void UnpackKernel(const int size, const int output_num, void UnpackKernel(const size_t size, const size_t output_num,
const int dims_after_axis, T** outputs, const T* input, const size_t dims_after_axis, T** outputs, const T* input,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNPACKIMPL_H_ #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNPACKIMPL_H_

@ -47,20 +47,20 @@ def pack(nptype):
pack_ = PackNet(nptype) pack_ = PackNet(nptype)
output = pack_() output = pack_()
expect = np.array([[[[[0, 0], expect = np.array([[[[[0, 0],
[0, 1]], [0, 0]],
[[0, 0], [[0, 1],
[2, 3]]], [2, 3]]],
[[[0, 0], [[[0, 0],
[4, 5]], [0, 0]],
[[0, 0], [[4, 5],
[6, 7]]]], [6, 7]]]],
[[[[0, 0], [[[[0, 0],
[8, 9]], [0, 0]],
[[0, 0], [[8, 9],
[10, 11]]], [10, 11]]],
[[[0, 0], [[[0, 0],
[12, 13]], [0, 0]],
[[0, 0], [[12, 13],
[14, 15]]]]]).astype(nptype) [14, 15]]]]]).astype(nptype)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
@ -71,20 +71,20 @@ def pack_pynative(nptype):
x1 = Tensor(x1) x1 = Tensor(x1)
x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(nptype)) x2 = Tensor(np.arange(16).reshape(2, 2, 2, 2).astype(nptype))
expect = np.array([[[[[0, 0], expect = np.array([[[[[0, 0],
[0, 1]], [0, 0]],
[[0, 0], [[0, 1],
[2, 3]]], [2, 3]]],
[[[0, 0], [[[0, 0],
[4, 5]], [0, 0]],
[[0, 0], [[4, 5],
[6, 7]]]], [6, 7]]]],
[[[[0, 0], [[[[0, 0],
[8, 9]], [0, 0]],
[[0, 0], [[8, 9],
[10, 11]]], [10, 11]]],
[[[0, 0], [[[0, 0],
[12, 13]], [0, 0]],
[[0, 0], [[12, 13],
[14, 15]]]]]).astype(nptype) [14, 15]]]]]).astype(nptype)
output = P.Pack(axis=2)((x1, x2)) output = P.Pack(axis=2)((x1, x2))
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()

Loading…
Cancel
Save