From: @jonwe
Reviewed-by: @robingrosman,@tom__chen
Signed-off-by: @tom__chen
pull/8337/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 10592a80b9

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CAST_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename S, typename T>
class CastGpuKernel : public GpuKernel {
public:
CastGpuKernel() : input_size_(1), output_size_(1) {}
~CastGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
S *input_addr = GetDeviceAddress<S>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
Cast(input_size_, input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);
input_size_ = 1;
for (size_t i = 0; i < input_shapes.size(); i++) {
input_size_ *= input_shapes[i];
}
output_size_ = 1;
for (size_t j = 0; j < output_shapes.size(); j++) {
output_size_ *= output_shapes[j];
}
if (input_size_ != output_size_) {
MS_LOG(EXCEPTION) << "Input size is not equal to output size.";
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_ * sizeof(T));
output_size_list_.push_back(output_size_ * sizeof(T));
}
private:
int input_size_;
int output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CAST_GPU_KERNEL_H_

@ -0,0 +1,177 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/cast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
// Cast
template <typename S, typename T>
__global__ void CastKernel(const int input_size, const S *input_addr, T *output_addr) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size; pos += blockDim.x * gridDim.x) {
output_addr[pos] = static_cast<T>(input_addr[pos]);
}
}
template <typename S, typename T>
void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream) {
CastKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input_addr, output_addr);
}
template void Cast(const int input_size, const int8_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int8_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int16_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int32_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, double *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const int64_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const int64_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint8_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint16_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint32_t *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, double *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const uint64_t *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const uint64_t *input_addr, bool *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const half *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, int32_t *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const half *input_addr, int64_t *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const half *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, uint32_t *output_addr, cudaStream_t stream);
// template void Cast(const int input_size, const half *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const half *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const float *input_addr, bool *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, int64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, uint8_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, uint16_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, uint32_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, uint64_t *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, float *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, double *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, half *output_addr, cudaStream_t stream);
template void Cast(const int input_size, const double *input_addr, bool *output_addr, cudaStream_t stream);

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_H_
#include <vector>
#include "runtime/device/gpu/cuda_common.h"
template <typename S, typename T>
void Cast(const int input_size, const S *input_addr, T *output_addr, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_H_

@ -444,7 +444,8 @@ AbstractBasePtr InferImplMemCpyAsync(const AnalysisEnginePtr &, const PrimitiveP
AbstractBasePtr InferImplCast(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
// GPU has 2 inputs while tbe has 1 only. Skip CheckArgsSize.
// CheckArgsSize(op_name, args_spec_list, 1);
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_x->shape());

@ -309,6 +309,7 @@ class Cast(PrimitiveWithInfer):
dst_type = dst_type.element_type()
self.add_prim_attr('DstT', dst_type)
self.add_prim_attr('SrcT', src_type)
self.add_prim_attr('dst_type', dst_type)
value = None
if x['value'] is not None:

Loading…
Cancel
Save