From 5a8c8997891ba10852ce23a64e6f5fcf3ec743a2 Mon Sep 17 00:00:00 2001
From: TFBunny
Date: Tue, 2 Feb 2021 18:35:46 -0500
Subject: [PATCH] Rework GPU print, supporting pynative mode and graph mode

---
 .../gpu/debug/print_gpu_kernel.cc             |  34 ++---
 .../gpu/debug/print_gpu_kernel.h              | 103 ++++++++++---
 mindspore/ops/operations/debug_ops.py         |   3 +-
 tests/st/ops/gpu/test_print_op.py             | 135 ++++++++++++++++++
 4 files changed, 236 insertions(+), 39 deletions(-)
 create mode 100644 tests/st/ops/gpu/test_print_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.cc
index 62bb23bd44..3497a5a78a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.cc
@@ -19,37 +19,37 @@
 namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, bool)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int8_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int16_t)
 MS_REG_GPU_KERNEL_ONE(Print,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, int64_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
                       PrintGpuKernel, uint8_t)
 MS_REG_GPU_KERNEL_ONE(Print,
-                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
-                      PrintGpuKernel, bool)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
-  PrintGpuKernel, uint16_t)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
-  PrintGpuKernel, uint32_t)
-MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
-  PrintGpuKernel, uint64_t)
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint16_t)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint32_t)
+MS_REG_GPU_KERNEL_ONE(Print,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
+                      PrintGpuKernel, uint64_t)
 MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
   PrintGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(
-  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
   PrintGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
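With the registrations above, every supported input type (bool, the signed and unsigned integers, float16, float32) now maps to a PrintGpuKernel instantiation with a placeholder int32 output, so Print behaves like an ordinary GPU kernel in both execution modes. As a usage sketch (not part of the patch; assumes a CUDA-capable device and that pynative dispatch works as the commit title claims):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

# Pynative mode executes the kernel eagerly; this is one of the two
# modes the rework targets.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
print_op = P.Print()
x = Tensor(np.arange(6).reshape(2, 3).astype(np.float16))
print_op(x)  # copies x back to host and prints its Tensor form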
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h
index 8aa9fda1e1..d409193f98 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/debug/print_gpu_kernel.h
@@ -17,11 +17,17 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
 
+#include <algorithm>
+#include <memory>
+#include <string>
 #include <iostream>
 #include <vector>
+#include "ir/tensor.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 
+using mindspore::tensor::Tensor;
+
 namespace mindspore {
 namespace kernel {
 template <typename T>
@@ -37,19 +43,42 @@ class PrintGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     VARIABLE_NOT_USED(workspace);
-    VARIABLE_NOT_USED(outputs);
     for (size_t i = 0; i < inputs.size(); i++) {
       input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
     }
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudaMemcpy(&input_host_data_[0], &input_device_data_[0], input_size_ * sizeof(T), cudaMemcpyDeviceToHost),
-      "cudaMemcpy output failed");
-    for (size_t i = 0; i < input_num_.size(); i++) {
-      for (size_t j = 0; j < input_num_[i]; j++) {
-        std::cout << input_host_data_[i][j];
-      }
+    int *output_address = GetDeviceAddress<int>(outputs, 0);
+    // host initialization
+    std::vector<std::unique_ptr<T[]>> input_host_data;
+    for (size_t i = 0; i < input_size_.size(); i++) {
+      std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
+      input_host_data.push_back(std::move(value));
+    }
+    // check type
+    T type_value = static_cast<T>(0.0f);
+    auto type_id = CheckType(type_value);
+    if (type_id == kTypeUnknown) {
+      MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
     }
+    // print core function
+    for (size_t i = 0; i < input_host_data.size(); i++) {
+      std::string error_msg = "cudaMemcpy print loop failed at input_device_data[";
+      error_msg.append(std::to_string(i));
+      error_msg.append("].");
+      CHECK_CUDA_RET_WITH_EXCEPT(
+        kernel_node_,
+        cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost),
+        error_msg);
+      ShapeVector shape;
+      (void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape),
+                           [](const size_t &value) { return static_cast<int64_t>(value); });
+      Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T));
+      std::cout << current_tensor.ToString() << std::endl;
+    }
+    int output = 1;
+    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
+                               cudaMemcpyAsync(output_address, &output, sizeof(int), cudaMemcpyHostToDevice,
+                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "cudaMemcpyAsync output failed");
     return true;
   }
 
@@ -57,38 +86,70 @@
     kernel_node_ = kernel_node;
     size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
     input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
-    input_host_data_ = std::make_unique<T *[]>(input_tensor_num);
+    std::vector<size_t> value_shape;
     for (size_t i = 0; i < input_tensor_num; i++) {
-      size_t counter = 0;
+      size_t value = 1;
       auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
-        input_size_ *= input_shape[j];
-        counter++;
+        value *= input_shape[j];
+        value_shape.push_back(input_shape[j]);
       }
-      input_num_.push_back(counter);
+      input_size_.push_back(value);
+      input_shape_.push_back(value_shape);
+      value_shape.clear();
     }
     InitSizeLists();
     return true;
   }
 
   void ResetResource() noexcept override {
-    input_size_ = 1;
     input_device_data_ = nullptr;
-    input_host_data_ = nullptr;
-    input_num_.clear();
+    input_size_.clear();
+    input_shape_.clear();
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
   }
 
  protected:
-  void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); }
+  void InitSizeLists() override {
+    for (size_t i = 0; i < input_size_.size(); i++) {
+      input_size_list_.push_back(input_size_[i] * sizeof(T));
+    }
+    output_size_list_.push_back(sizeof(int));
+  }
+
+  TypeId CheckType(T value) {
+    if (std::is_same<T, bool>::value) {
+      return kNumberTypeBool;
+    } else if (std::is_same<T, int8_t>::value) {
+      return kNumberTypeInt8;
+    } else if (std::is_same<T, int16_t>::value) {
+      return kNumberTypeInt16;
+    } else if (std::is_same<T, int>::value) {
+      return kNumberTypeInt32;
+    } else if (std::is_same<T, int64_t>::value) {
+      return kNumberTypeInt64;
+    } else if (std::is_same<T, uint8_t>::value) {
+      return kNumberTypeUInt8;
+    } else if (std::is_same<T, uint16_t>::value) {
+      return kNumberTypeUInt16;
+    } else if (std::is_same<T, uint32_t>::value) {
+      return kNumberTypeUInt32;
+    } else if (std::is_same<T, uint64_t>::value) {
+      return kNumberTypeUInt64;
+    } else if (std::is_same<T, half>::value) {
+      return kNumberTypeFloat16;
+    } else if (std::is_same<T, float>::value) {
+      return kNumberTypeFloat32;
+    }
+    return kTypeUnknown;
+  }
 
  private:
-  size_t input_size_;
   std::unique_ptr<T *[]> input_device_data_;
-  std::unique_ptr<T *[]> input_host_data_;
-  std::vector<size_t> input_num_;
+  std::vector<size_t> input_size_;
+  std::vector<std::vector<size_t>> input_shape_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
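The reworked Launch above does one device-to-host cudaMemcpy per input, wraps each host buffer in a mindspore::tensor::Tensor so that ToString handles shape and dtype formatting, and finally writes a constant int32 through the kernel's output address. A rough Python rendering of that control flow, purely for orientation (launch_sketch and its arguments are hypothetical names, with NumPy standing in for CUDA):

import numpy as np

def launch_sketch(device_buffers, shapes, dtype):
    # one cudaMemcpy(DeviceToHost) per input in the real kernel
    for buf, shape in zip(device_buffers, shapes):
        host = np.frombuffer(buf, dtype=dtype).reshape(shape)
        print(host)  # stands in for Tensor::ToString()
    # the kernel then writes a constant 1 to its int32 output slot
    return np.int32(1)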
diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py
index baee3da117..1aca4c50e1 100644
--- a/mindspore/ops/operations/debug_ops.py
+++ b/mindspore/ops/operations/debug_ops.py
@@ -341,10 +341,11 @@ class Print(PrimitiveWithInfer):
         In pynative mode, please use python print function.
         In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print,
         str remains unchanged.
+        On GPU, all input elements must be of the same type, and string is not supported.
 
     Inputs:
         - **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
-          Supports multiple inputs which are separated by ','.
+          Supports multiple inputs which are separated by ','. GPU does not support string as an input.
 
     Supported Platforms:
         ``Ascend`` ``GPU``
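In graph mode the constraint in the updated docstring looks like this (a minimal example modeled on the tests below; both inputs share one dtype, and no string is passed on GPU):

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

class PrintTwoInputs(nn.Cell):
    def __init__(self):
        super(PrintTwoInputs, self).__init__()
        self.op = P.Print()

    def construct(self, x, y):
        self.op(x, y)  # both float32: the GPU kernel requires matching types
        return x

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = PrintTwoInputs()
net(Tensor(np.ones((2, 2), np.float32)), Tensor(np.zeros((3,), np.float32)))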
diff --git a/tests/st/ops/gpu/test_print_op.py b/tests/st/ops/gpu/test_print_op.py
new file mode 100644
index 0000000000..cd53976c75
--- /dev/null
+++ b/tests/st/ops/gpu/test_print_op.py
@@ -0,0 +1,135 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import numpy as np
+import pytest
+
+from mindspore import Tensor
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+import mindspore.context as context
+
+
+class PrintNetOneInput(nn.Cell):
+    def __init__(self):
+        super(PrintNetOneInput, self).__init__()
+        self.op = P.Print()
+
+    def construct(self, x):
+        self.op(x)
+        return x
+
+
+class PrintNetTwoInputs(nn.Cell):
+    def __init__(self):
+        super(PrintNetTwoInputs, self).__init__()
+        self.op = P.Print()
+
+    def construct(self, x, y):
+        self.op(x, y)
+        return x
+
+
+def print_testcase(nptype):
+    # large shape
+    x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
+    # small shape
+    y = np.arange(9).reshape(3, 3).astype(nptype)
+    x = Tensor(x)
+    y = Tensor(y)
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net_1 = PrintNetOneInput()
+    net_2 = PrintNetTwoInputs()
+    net_1(x)
+    net_2(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_bool():
+    print_testcase(np.bool)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int8():
+    print_testcase(np.int8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int16():
+    print_testcase(np.int16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int32():
+    print_testcase(np.int32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_int64():
+    print_testcase(np.int64)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint8():
+    print_testcase(np.uint8)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint16():
+    print_testcase(np.uint16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint32():
+    print_testcase(np.uint32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_uint64():
+    print_testcase(np.uint64)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_float16():
+    print_testcase(np.float16)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_print_float32():
+    print_testcase(np.float32)
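The tests above exercise graph mode only. Since the commit also claims pynative support, a pynative-mode variant of the same driver would look like the sketch below (assumed behavior, reusing PrintNetOneInput and the imports from the test file; not part of the patch):

def print_testcase_pynative(nptype):
    # pynative mode runs the Print kernel eagerly, without graph compilation
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x = Tensor(np.arange(9).reshape(3, 3).astype(nptype))
    net = PrintNetOneInput()
    net(x)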