diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cu
new file mode 100644
index 0000000000..e7037ef6de
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cu
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstddef>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include "backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cuh"
+
+template <typename T, typename S>
+__global__ void CastAll(T** inputs, S** output, const size_t num, const size_t *size) {
+  for (size_t i = 0; i < num; i++) {
+    for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size[i]; pos += blockDim.x * gridDim.x) {
+      output[i][pos] = static_cast<S>(inputs[i][pos]);
+    }
+  }
+}
+
+template <typename T, typename S>
+void CastAllKernel(T** inputs, S** output, const size_t max, const size_t num, const size_t *size,
+                   cudaStream_t stream) {
+  CastAll<<<GET_BLOCKS(max), GET_THREADS, 0, stream>>>(inputs, output, num, size);
+  return;
+}
+template void CastAllKernel(half** inputs, float** output, const size_t max, const size_t num,
+                            const size_t *size, cudaStream_t stream);
+template void CastAllKernel(float** inputs, half** output, const size_t max, const size_t num,
+                            const size_t *size, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cuh
new file mode 100644
index 0000000000..f1bbdc5671
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cuh
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_ALL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_ALL_H_
+
+#include <cuda_runtime.h>
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T, typename S>
+void CastAllKernel(T **inputs, S **output, const size_t max, const size_t num, const size_t *size, cudaStream_t stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CAST_ALL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.cc
new file mode 100644
index 0000000000..0e8fb719ab
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.cc
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  CastAll, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
+  CastAllGpuFwdKernel, float, half)
+MS_REG_GPU_KERNEL_TWO(
+  CastAll, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
+  CastAllGpuFwdKernel, half, float)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.h
new file mode 100644
index 0000000000..6d9ac9ef33
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/cast_all_gpu_kernel.h
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CAST_ALL_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CAST_ALL_GPU_KERNEL_H
+
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/cast_all_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class CastAllGpuFwdKernel : public GpuKernel {
+ public:
+  CastAllGpuFwdKernel() : max_(0), input_size_(0), output_size_(0), num_input_(0) {}
+  ~CastAllGpuFwdKernel() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
+    auto in_addr = std::make_unique<T *[]>(num_input_);
+    auto out_addr = std::make_unique<S *[]>(num_input_);
+    for (size_t i = 0; i < num_input_; i++) {
+      in_addr[i] = GetDeviceAddress<T>(inputs, i);
+      out_addr[i] = GetDeviceAddress<S>(outputs, i);
+    }
+    T **inputs_dev = GetDeviceAddress<T *>(workspace, 0);
+    S **outputs_dev = GetDeviceAddress<S *>(workspace, 1);
+    size_t *size_dev = GetDeviceAddress<size_t>(workspace, 2);
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpyAsync(inputs_dev, in_addr.get(), sizeof(T *) * num_input_, cudaMemcpyHostToDevice, stream),
+      "cudaMemCPY failed")
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpyAsync(outputs_dev, out_addr.get(), sizeof(S *) * num_input_, cudaMemcpyHostToDevice, stream),
+      "cudaMemCPY failed")
+    CHECK_CUDA_RET_WITH_EXCEPT(
+      cudaMemcpyAsync(size_dev, size_.get(), sizeof(size_t) * num_input_, cudaMemcpyHostToDevice, stream),
+      "cudaMemCPY failed")
+    CastAllKernel(inputs_dev, outputs_dev, max_, num_input_, size_dev, stream);
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    num_input_ = GetAttr<size_t>(kernel_node, "n");
+    size_ = std::make_unique<size_t[]>(num_input_);
+    for (size_t i = 0; i < num_input_; i++) {
+      auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+      size_t s = 1;
+      for (auto x : shape) {
+        s = s * x;
+      }
+      if (max_ < s) {
+        max_ = s;
+      }
+      size_[i] = s;
+      input_size_ = sizeof(T) * s;
+      output_size_ = sizeof(S) * s;
+      InitSizeLists();
+    }
+    workspace_size_list_.push_back(sizeof(T *) * num_input_);
+    workspace_size_list_.push_back(sizeof(S *) * num_input_);
+    workspace_size_list_.push_back(sizeof(size_t) * num_input_);
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+  }
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+  std::unique_ptr<size_t[]> size_;
+  size_t max_;
+  size_t input_size_;
+  size_t output_size_;
+  size_t num_input_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_CAST_ALL_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc
new file mode 100644
index 0000000000..89ff6e80d0
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc
@@ -0,0 +1,140 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/gpu/combine_cast_fusion.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector<AnfNodePtr> &node_list) {
+  std::vector<std::string> inputs_device_format;
+  std::vector<std::string> outputs_device_format;
+  std::vector<TypeId> inputs_device_type;
+  std::vector<TypeId> outputs_device_type;
+  std::vector<std::vector<size_t>> outputs_shape;
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  for (size_t idx = 0; idx < node_list.size(); ++idx) {
+    auto cnode = utils::cast<CNodePtr>(node_list[idx]);
+    MS_EXCEPTION_IF_NULL(cnode);
+    for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
+      inputs_device_format.push_back(kOpFormat_DEFAULT);
+      inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index));
+    }
+    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) {
+      outputs_device_format.push_back(kOpFormat_DEFAULT);
+      outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index));
+      outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index));
+    }
+  }
+  builder.SetInputsFormat(inputs_device_format);
+  builder.SetOutputsFormat(outputs_device_format);
+  builder.SetInputsDeviceType(inputs_device_type);
+  builder.SetOutputsDeviceType(outputs_device_type);
+  return builder.Build();
+}
+
+bool GetDealList(const std::vector<AnfNodePtr> &node_list, std::vector<std::vector<AnfNodePtr>> *deal_list) {
+  std::vector<AnfNodePtr> cast_32to16_list;
+  std::vector<AnfNodePtr> cast_16to32_list;
+  for (auto &cast_node : node_list) {
+    // currently, we only deal with the construct [Param->Cast->] to avoid creating a cycle.
+    if (cast_node != nullptr && cast_node->isa<CNode>() && AnfAlgo::GetCNodeName(cast_node) == "Cast" &&
+        (AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_node), 0))->isa<Parameter>()) {
+      auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0);
+      auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0);
+      if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) {
+        cast_32to16_list.push_back(cast_node);
+      } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) {
+        cast_16to32_list.push_back(cast_node);
+      }
+    }
+  }
+  if (cast_32to16_list.size() <= 1 && cast_16to32_list.size() <= 1) {
+    return false;
+  }
+  if (cast_32to16_list.size() > 1) {
+    deal_list->push_back(cast_32to16_list);
+  }
+  if (cast_16to32_list.size() > 1) {
+    deal_list->push_back(cast_16to32_list);
+  }
+  return true;
+}
+}  // namespace
+bool CastAllFusion::Run(const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
+  // 1 get all the Cast nodes
+  std::vector<std::vector<AnfNodePtr>> deal_list;
+  if (!GetDealList(node_list, &deal_list)) {
+    return false;
+  }
+  for (auto cast_list : deal_list) {
+    // 2 create the CastAll node
+    auto prim = std::make_shared<Primitive>("CastAll");
+    std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
+    // set inputs for CastAll
+    for (size_t idx = 0; idx < cast_list.size(); ++idx) {
+      inputs.push_back(AnfAlgo::GetInputNode(utils::cast<CNodePtr>(cast_list[idx]), 0));
+    }
+    auto cast_all = graph->NewCNode(inputs);
+    auto kernel_info = std::make_shared<device::KernelInfo>();
+    MS_EXCEPTION_IF_NULL(kernel_info);
+    cast_all->set_kernel_info(kernel_info);
+    AbstractBasePtrList abstract_list;
+    for (size_t idx = 0; idx < cast_list.size(); ++idx) {
+      auto cnode = utils::cast<CNodePtr>(cast_list[idx]);
+      MS_EXCEPTION_IF_NULL(cnode);
+      abstract_list.push_back(cnode->abstract());
+    }
+    auto kernel_build_info = GenerateKernelBuildInfo(cast_list);
+    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, cast_all.get());
+    auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
+    MS_EXCEPTION_IF_NULL(abstract_tuple);
+    cast_all->set_abstract(abstract_tuple);
+    AnfAlgo::SetNodeAttr("n", MakeValue(cast_list.size()), cast_all);
+    // 3 replace every Cast with tuple_getitem(cast_all, idx)
+    for (size_t idx = 0; idx < cast_list.size(); ++idx) {
+      std::vector<AnfNodePtr> tuple_getitem_input;
+      tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
+      tuple_getitem_input.push_back(cast_all);
+      auto index = NewValueNode(SizeToInt(idx));
+      auto imm = std::make_shared<Int32Imm>(idx);
+      auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
+      MS_EXCEPTION_IF_NULL(abstract_scalar);
+      index->set_abstract(abstract_scalar);
+      tuple_getitem_input.push_back(index);
+      AnfNodePtr tuple_getitem = graph->NewCNode(tuple_getitem_input);
+      MS_EXCEPTION_IF_NULL(tuple_getitem);
+      tuple_getitem->set_abstract(cast_list[idx]->abstract());
+      if (!manager->Replace(cast_list[idx], tuple_getitem)) {
+        MS_LOG(EXCEPTION) << "manager replace node failed";
+      }
+    }
+  }
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.h
new file mode 100644
index 0000000000..3d9c2f8650
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_CAST_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_CAST_FUSION_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class CastAllFusion : public Pass {
+ public:
+  explicit CastAllFusion(const std::string &name) : Pass("cast_all") {}
+  ~CastAllFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_COMBINE_CAST_FUSION_H_
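For context, the following standalone sketch illustrates the batching idea the new CastAll kernel relies on: a single launch walks a device-side table of input/output pointers and per-buffer element counts with a grid-stride loop, so several independent float-to-half casts share one kernel invocation. This sketch is not part of the patch; the buffer count, buffer sizes, and the 256-thread launch configuration are illustrative assumptions, and error handling is kept minimal.

// Illustrative sketch only (not part of the patch): two float buffers cast to half in one launch.
// Buffer sizes and launch configuration below are arbitrary example values.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void CastAllSketch(float **inputs, half **outputs, size_t num, const size_t *sizes) {
  // Each thread grid-strides over every buffer; buffers shorter than the grid simply finish early.
  for (size_t i = 0; i < num; i++) {
    for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < sizes[i]; pos += blockDim.x * gridDim.x) {
      outputs[i][pos] = __float2half(inputs[i][pos]);
    }
  }
}

int main() {
  const size_t sizes_host[2] = {1024, 4096};
  std::vector<float *> in_host(2);
  std::vector<half *> out_host(2);
  for (int i = 0; i < 2; i++) {
    cudaMalloc(reinterpret_cast<void **>(&in_host[i]), sizes_host[i] * sizeof(float));
    cudaMalloc(reinterpret_cast<void **>(&out_host[i]), sizes_host[i] * sizeof(half));
  }
  // Device-side copies of the pointer tables and sizes, mirroring the kernel's three workspace buffers.
  float **in_dev = nullptr;
  half **out_dev = nullptr;
  size_t *sizes_dev = nullptr;
  cudaMalloc(reinterpret_cast<void **>(&in_dev), 2 * sizeof(float *));
  cudaMalloc(reinterpret_cast<void **>(&out_dev), 2 * sizeof(half *));
  cudaMalloc(reinterpret_cast<void **>(&sizes_dev), 2 * sizeof(size_t));
  cudaMemcpy(in_dev, in_host.data(), 2 * sizeof(float *), cudaMemcpyHostToDevice);
  cudaMemcpy(out_dev, out_host.data(), 2 * sizeof(half *), cudaMemcpyHostToDevice);
  cudaMemcpy(sizes_dev, sizes_host, 2 * sizeof(size_t), cudaMemcpyHostToDevice);
  // Size the grid for the largest buffer so one launch covers every buffer.
  const int threads = 256;
  const int blocks = static_cast<int>((4096 + threads - 1) / threads);
  CastAllSketch<<<blocks, threads>>>(in_dev, out_dev, 2, sizes_dev);
  cudaDeviceSynchronize();
  printf("CastAllSketch finished: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}

Sizing the grid for the largest buffer is the role of the max argument threaded through CastAllKernel in the patch: smaller buffers just exit their grid-stride loop early, so no per-buffer launch is needed.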