|
|
|
@ -27,7 +27,8 @@ namespace kernel {
|
|
|
|
|
template <typename T, typename S>
|
|
|
|
|
class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
|
|
|
|
public:
|
|
|
|
|
// NOTE(review): this one-line constructor duplicates the constructor that follows and appears to be
// the pre-change side of the diff hunk -- confirm. It sets the four dimension fields to 1 and does
// not mention is_null_input_ in its initializer list.
UnsortedSegmentSumGpuKernel() : input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1) {}
|
|
|
|
|
// Default constructor: initializes the four dimension fields to 1 and is_null_input_ to false.
UnsortedSegmentSumGpuKernel()
|
|
|
|
|
: input_dim0_(1), input_dim1_(1), output_dim0_(1), output_dim1_(1), is_null_input_(false) {}
|
|
|
|
|
// Compiler-generated destructor, declared as an override of the base-class destructor.
~UnsortedSegmentSumGpuKernel() override = default;
|
|
|
|
|
|
|
|
|
|
// Returns a const reference to input_size_list_; the caller receives no copy.
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
|
|
|
@ -36,6 +37,9 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
|
|
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
|
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
|
|
|
|
if (is_null_input_) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
|
|
|
|
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
|
|
|
|
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
|
|
|
@ -50,6 +54,12 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
|
|
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override {
|
|
|
|
|
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
is_null_input_ = CHECK_NULL_INPUT(input_shapes);
|
|
|
|
|
if (is_null_input_) {
|
|
|
|
|
MS_LOG(WARNING) << "UnsortedSegmentSum input is null";
|
|
|
|
|
InitSizeLists();
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
|
|
|
|
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
|
|
|
|
|
@ -83,6 +93,7 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {
|
|
|
|
|
// Dimension fields; each is initialized to 1 by the constructor.
// NOTE(review): how they are derived from the input/output shapes is not visible in this excerpt.
size_t input_dim1_;
|
|
|
|
|
size_t output_dim0_;
|
|
|
|
|
size_t output_dim1_;
|
|
|
|
|
// Set in Init() from CHECK_NULL_INPUT on the inferred shape of input 0. When true, Init() logs a
// warning and returns early, and Launch() returns true before reading any device addresses.
bool is_null_input_;
|
|
|
|
|
|
|
|
|
|
// Exposed to callers by const reference through GetInputSizeList().
std::vector<size_t> input_size_list_;
|
|
|
|
|
// NOTE(review): presumably exposed through a matching output-size accessor, which is not visible
// in this excerpt -- confirm.
std::vector<size_t> output_size_list_;
|
|
|
|
|