|
|
|
@ -27,7 +27,8 @@ namespace kernel {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class SliceGpuFwdKernel : public GpuKernel {
|
|
|
|
|
public:
|
|
|
|
|
SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
|
|
|
|
|
SliceGpuFwdKernel()
|
|
|
|
|
: is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
|
|
|
|
|
~SliceGpuFwdKernel() override = default;
|
|
|
|
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
|
|
|
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
|
|
|
@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|
|
|
|
|
|
|
|
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
|
|
|
|
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
|
|
|
|
if (is_null_input_) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
T *input = GetDeviceAddress<T>(inputs, 0);
|
|
|
|
|
T *output = GetDeviceAddress<T>(outputs, 0);
|
|
|
|
|
if (is_strided_slice_) {
|
|
|
|
@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|
|
|
|
if (size_[i] < 0) {
|
|
|
|
|
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
|
|
|
|
|
}
|
|
|
|
|
if (size_[i] == 0) {
|
|
|
|
|
if (begin_[i] == size_[i] && is_strided_slice_) {
|
|
|
|
|
MS_LOG(WARNING) << "Output is null.";
|
|
|
|
|
is_null_input_ = true;
|
|
|
|
|
}
|
|
|
|
|
if (size_[i] == 0 && strides_[i] > 0) {
|
|
|
|
|
size_[i] = begin_[i] + 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|
|
|
|
std::vector<size_t> workspace_size_list_;
|
|
|
|
|
|
|
|
|
|
bool is_strided_slice_;
|
|
|
|
|
bool is_null_input_;
|
|
|
|
|
size_t input_size_;
|
|
|
|
|
size_t output_size_;
|
|
|
|
|
size_t workspace_size_;
|
|
|
|
|