@@ -63,19 +63,21 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     if (!CheckParam(kernel_node)) {
       return false;
     }
 
     axis_ = GetAttr<int>(kernel_node, "axis");
     if (axis_ < 0) {
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+      auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
       axis_ += SizeToInt(input_shape.size());
     }
-
+    auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
+    auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
+    axis_ = AxisTransform(origin_data_format, input_format, axis_);
     input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
     inputs_host_ = std::make_unique<T *[]>(input_num_);
     len_axis_ = std::make_unique<int[]>(input_num_);
     for (int i = 0; i < input_num_; i++) {
       size_t input_size = 1;
-      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+      auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
         input_size *= input_shape[j];
       }
@@ -85,7 +87,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     workspace_size_list_.push_back(sizeof(T *) * input_num_);
     workspace_size_list_.push_back(sizeof(int) * input_num_);
 
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
     output_size_ = 1;
     for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
       output_size_ *= output_shape[i];
@@ -98,7 +100,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
       }
     }
     output_size_list_.push_back(output_size_ * sizeof(T));
-
     InitSizeLists();
     return true;
   }
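
Note on the AxisTransform call introduced in the first hunk: only its call site appears in this patch, so the following is a minimal sketch of what such a helper plausibly does, assuming it remaps a (by then non-negative) concat axis from the graph's origin layout to the layout the device tensors actually use, e.g. between NCHW and NHWC. The name AxisTransformSketch and the permutation tables are illustrative assumptions, not taken from this patch.

#include <array>
#include <string>

// Sketch of an axis remap between NCHW and NHWC layouts (assumed behavior;
// the real AxisTransform is defined elsewhere in the kernel sources).
int AxisTransformSketch(const std::string &origin_format, const std::string &device_format, int axis) {
  if (origin_format == device_format) {
    return axis;  // layouts agree, nothing to remap
  }
  if (origin_format == "NCHW" && device_format == "NHWC") {
    // Dimension i of NCHW lands at this position in NHWC: N->0, C->3, H->1, W->2.
    static const std::array<int, 4> kPerm = {0, 3, 1, 2};
    return kPerm[axis];
  }
  if (origin_format == "NHWC" && device_format == "NCHW") {
    // Inverse mapping: N->0, H->2, W->3, C->1.
    static const std::array<int, 4> kPerm = {0, 2, 3, 1};
    return kPerm[axis];
  }
  return axis;  // unknown format pair: pass the axis through in this sketch
}

Under that reading, swapping GetPrevNodeOutputInferShape/GetOutputInferShape for GetInputDeviceShape/GetOutputDeviceShape is the matching half of the change: once axis_ is expressed in the device layout, len_axis_ and output_size_ have to be computed from device-layout shapes as well, or the sizes along the concat axis would be read from the wrong dimension.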