@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_GATHER_GPU_KERNEL_H
-#define MINDSPORE_GATHER_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
 
 #include <vector>
 #include <algorithm>
@@ -41,45 +41,17 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     T *input_addr = GetDeviceAddress<T>(inputs, 0);
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
-
     if (is_dynamic_shape_) {
-      // if we are in dynamic shape mode, we don't know dims_, so we need to store the input_shape_ and indices_shape_,
-      // and axis_ in the workspace to calculate dims_
-      size_t *input_shape_device_address = GetDeviceAddress<size_t>(workspace, 0);
-      size_t *indices_shape_device_address = GetDeviceAddress<size_t>(workspace, 1);
-      int64_t *axis_device_address = GetDeviceAddress<int64_t>(workspace, 2);
-
-      CHECK_CUDA_RET_WITH_EXCEPT(
-        cudaMemcpyAsync(input_shape_device_address, input_shapes_.data(), workspace_size_list_[0],
-                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "cudaMemcpyAsync input_shape failed");
-      CHECK_CUDA_RET_WITH_EXCEPT(
-        cudaMemcpyAsync(indices_shape_device_address, indices_shapes_.data(), workspace_size_list_[1],
-                        cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "cudaMemcpyAsync indices_shape failed");
-      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(axis_device_address, &axis_, workspace_size_list_[2],
-                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      int64_t *axis_device_address = GetDeviceAddress<int64_t>(inputs, 2);  // only get this if in dynamic mode
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(&axis_, axis_device_address, sizeof(int64_t), cudaMemcpyDeviceToHost,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
                                  "cudaMemcpyAsync axis_ failed");
-
-      // output shape will be here for us to copy back to host
-      size_t *output_shape_device_address = GetDeviceAddress<size_t>(workspace, 3);
-      CalGatherV2DynamicShape(input_addr, indices_addr, output_addr, input_shape_device_address, input_shapes_.size(),
-                              indices_shape_device_address, indices_shapes_.size(), axis_device_address,
-                              output_shape_device_address, max_output_size_,
-                              reinterpret_cast<cudaStream_t>(stream_ptr));
-
-      size_t output_rank = input_shapes_.size() - 1 + indices_shapes_.size();
-      real_output_shape_.resize(output_rank);
-      CHECK_CUDA_RET_WITH_ERROR(
-        cudaMemcpyAsync(&real_output_shape_[0], output_shape_device_address, output_rank * sizeof(int32_t),
-                        cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
-        "Failed to copy gpu memory.");
-
-    } else {
-      auto input_dim1 = input_shapes_[IntToSize(axis_)];
-      CalGatherV2StaticShape(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
-                             reinterpret_cast<cudaStream_t>(stream_ptr));
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaDeviceSynchronize(), "cudaDeviceSyncFailed - GatherV2 - in dynamic mode");
+      Reshape();
     }
+    auto input_dim1 = input_shapes_[IntToSize(axis_)];
+    GatherV2(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], input_dim1,
+             reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
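In the reworked dynamic-shape path above, the gather axis is no longer staged through workspace memory: it arrives as a third input tensor resident on the device, so Launch() copies the scalar back to the host and synchronizes before Reshape() can use it. A minimal standalone CUDA sketch of that copy-then-sync pattern (names are illustrative, not the MindSpore API):

// Standalone sketch of the copy-then-synchronize pattern used in the
// dynamic-shape branch above: a scalar parameter (the gather axis) sits
// in device memory, so the host copies it back and waits for the copy
// to finish before branching on the value. Names are illustrative only.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Stand-in for the third kernel input that carries the axis.
  int64_t host_value = 1;
  int64_t *device_axis = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&device_axis), sizeof(int64_t));
  cudaMemcpyAsync(device_axis, &host_value, sizeof(int64_t), cudaMemcpyHostToDevice, stream);

  // Device-to-host copy of the scalar, as in the patched Launch().
  int64_t axis = 0;
  cudaMemcpyAsync(&axis, device_axis, sizeof(int64_t), cudaMemcpyDeviceToHost, stream);
  // The copy is asynchronous: the host must not read `axis` until it completes.
  cudaDeviceSynchronize();

  printf("axis = %lld\n", static_cast<long long>(axis));
  cudaFree(device_axis);
  cudaStreamDestroy(stream);
  return 0;
}

A cudaStreamSynchronize(stream) on the launch stream would be sufficient for correctness here; cudaDeviceSynchronize(), as the patch uses, is stronger and stalls every stream on the device.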
@@ -87,33 +59,24 @@ class GatherV2GpuFwdKernel : public GpuKernel {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num == 3) {
       is_dynamic_shape_ = true;
-    } else if (input_num != 2) {
-      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2.";
+      MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Dynamic Mode.";
+    } else if (input_num == 2) {
+      MS_LOG(INFO) << " GatherGpuV2FwdKernel running in Normal Mode.";
+    } else {
+      MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherGpuV2FwdKernel needs 2 or 3.";
     }
 
     input_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
     indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
     output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
 
-    if (is_dynamic_shape_) {
-      c_node_ptr_ = kernel_node;
-      size_t input_shape_min = *std::min_element(input_shapes_.begin(), input_shapes_.end());
-      max_output_size_ = (GetSize(input_shapes_) / input_shape_min) * GetSize(indices_shapes_);
-    } else {
+    if (!is_dynamic_shape_) {
       axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
       if (axis_ < 0) {
         axis_ = axis_ + SizeToInt(input_shapes_.size());
       }
 
       Reshape();
     }
 
     InitSizeLists();
     return true;
   }
   void ResetResource() noexcept override {
     is_dynamic_shape_ = false;
-    max_output_size_ = -1;
     input_shapes_.clear();
     indices_shapes_.clear();
     output_shapes_.clear();
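GetSize() is called throughout Init() and InitSizeLists() but is defined outside the hunks shown. Judging from how its results are pushed into the byte-size lists, a plausible reconstruction (an assumption from the call sites, not the verified definition) is the element count multiplied by the element size:

// Assumed shape-to-bytes helper, reconstructed from its call sites above:
// the product of all dimensions, seeded with sizeof(T). In the kernel it
// would be a member of the class template, so T is the element type.
#include <cstddef>
#include <vector>

template <typename T>
size_t GetSize(const std::vector<size_t> &shape) {
  if (shape.empty()) {
    return 0;  // an empty shape contributes no bytes
  }
  size_t result = sizeof(T);
  for (size_t dim : shape) {
    result *= dim;
  }
  return result;
}

Under that reading, GetSize<float>({4, 5, 6}) is 4 * 5 * 6 * sizeof(float) = 480 bytes, which is the unit input_size_list_ and output_size_list_ traffic in.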
@@ -128,52 +91,32 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   void InitSizeLists() override {
     size_t size = GetSize(input_shapes_);
     input_size_list_.push_back(size);
 
     size = GetSize(indices_shapes_);
     input_size_list_.push_back(size);
 
     if (is_dynamic_shape_) {
-      // add by chenweifeng
-      input_size_list_.push_back(sizeof(S));
-
-      // allocate maximum size needed
-      output_size_list_.push_back(max_output_size_);
-
-      // allocate workspace memory for input, indices, axis, and output shape respectively
-      size = GetSize(input_shapes_);
-      workspace_size_list_.push_back(size);
-
-      size = GetSize(indices_shapes_);
-      workspace_size_list_.push_back(size);
-
-      size = sizeof(int32_t);
-      workspace_size_list_.push_back(size);
-
-      size = GetSize(input_shapes_);
-      workspace_size_list_.push_back(size);
-    } else {
-      size = GetSize(output_shapes_);
-      output_size_list_.push_back(size);
+      input_size_list_.push_back(sizeof(int64_t));
     }
+    size = GetSize(output_shapes_);
+    output_size_list_.push_back(size);
   }
 
  private:
   void Reshape() {
     if (axis_ < 0) {
       axis_ = axis_ + SizeToInt(input_shapes_.size());
     }
     size_t dim_before_axis = 1;
     for (size_t i = 0; i < IntToSize(axis_); i++) {
       dim_before_axis *= output_shapes_[i];
     }
 
     size_t dim_of_indices = 1;
     for (size_t i = 0; i < indices_shapes_.size(); i++) {
       dim_of_indices *= indices_shapes_[i];
     }
 
     size_t dim_after_indices = 1;
     for (size_t i = IntToSize(axis_) + indices_shapes_.size(); i < output_shapes_.size(); i++) {
       dim_after_indices *= output_shapes_[i];
     }
 
     dims_[0] = dim_before_axis;
     dims_[1] = dim_of_indices;
     dims_[2] = dim_after_indices;
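Reshape(), kept above as unchanged context, is the heart of the kernel: it flattens a gather of arbitrary rank into three extents, the output dims before the axis, the total number of gathered indices, and the trailing dims. A standalone walk-through with hypothetical shapes (not taken from the MindSpore sources):

// Standalone illustration of the Reshape() decomposition above. A gather
// along `axis` is reduced to dims (before-axis, indices, after-indices).
// The sample shapes are hypothetical.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<size_t> input_shape = {4, 5, 6};  // params tensor
  std::vector<size_t> indices_shape = {2, 3};   // indices tensor
  size_t axis = 1;

  // Output shape = input shape with the axis dim replaced by the indices
  // dims; here (4, 5, 6) gathered at axis 1 by (2, 3) gives (4, 2, 3, 6).
  std::vector<size_t> output_shape;
  for (size_t i = 0; i < axis; ++i) output_shape.push_back(input_shape[i]);
  for (size_t d : indices_shape) output_shape.push_back(d);
  for (size_t i = axis + 1; i < input_shape.size(); ++i) output_shape.push_back(input_shape[i]);

  // The same three products Reshape() stores in dims_[0..2].
  size_t dim_before_axis = 1;
  for (size_t i = 0; i < axis; ++i) dim_before_axis *= output_shape[i];

  size_t dim_of_indices = 1;
  for (size_t d : indices_shape) dim_of_indices *= d;

  size_t dim_after_indices = 1;
  for (size_t i = axis + indices_shape.size(); i < output_shape.size(); ++i)
    dim_after_indices *= output_shape[i];

  // Prints: dims = (4, 6, 6)
  printf("dims = (%zu, %zu, %zu)\n", dim_before_axis, dim_of_indices, dim_after_indices);
  return 0;
}

Those three extents, together with input_dim1 = input_shapes_[axis_] (5 in this example), are exactly the arguments the GatherV2 launch in the Launch() hunk passes to the device.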
@@ -193,14 +136,9 @@ class GatherV2GpuFwdKernel : public GpuKernel {
   std::vector<size_t> input_shapes_;
   std::vector<size_t> indices_shapes_;
   std::vector<size_t> output_shapes_;
 
   size_t dims_[3] = {};
   int64_t axis_;
   bool is_dynamic_shape_;
-  int max_output_size_;
-  std::vector<size_t> real_output_shape_;
-  CNodePtr c_node_ptr_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -208,4 +146,4 @@ class GatherV2GpuFwdKernel : public GpuKernel {
 } // namespace kernel
 } // namespace mindspore
 
-#endif // MINDSPORE_GATHER_GPU_KERNEL_H
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GATHER_V2_GPU_KERNEL_H_
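The GatherV2 device entry point itself lives in the .cu sources and is not part of this diff. Over the flattened (dims_[0], dims_[1], dims_[2]) view, the natural reading of the call is output[i][j][k] = input[i][indices[j]][k], with input_dim1 as the true middle extent of the input. A CPU sketch of that contract, with made-up shapes:

// CPU reference for the gather contract the device kernel implements on
// the flattened view: output[i][j][k] = input[i][indices[j]][k]. The
// shapes below are made up for illustration.
#include <cstddef>
#include <cstdio>
#include <vector>

void GatherV2Cpu(const std::vector<float> &input, const std::vector<int> &indices,
                 std::vector<float> *output, size_t dim0, size_t dim1, size_t dim2,
                 size_t input_dim1) {
  for (size_t i = 0; i < dim0; ++i) {
    for (size_t j = 0; j < dim1; ++j) {
      for (size_t k = 0; k < dim2; ++k) {
        (*output)[(i * dim1 + j) * dim2 + k] = input[(i * input_dim1 + indices[j]) * dim2 + k];
      }
    }
  }
}

int main() {
  // input is 2 x 3 x 2 (input_dim1 = 3); gather rows {2, 0} along axis 1.
  std::vector<float> input = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  std::vector<int> indices = {2, 0};
  std::vector<float> output(2 * 2 * 2);  // dims = (2, 2, 2)
  GatherV2Cpu(input, indices, &output, 2, 2, 2, 3);
  for (float v : output) {
    printf("%g ", v);  // prints: 4 5 0 1 10 11 6 7
  }
  printf("\n");
  return 0;
}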