Rework GPU print, supporting pynative mode and graph mode

pull/12014/head
TFBunny 4 years ago
parent c16b45ab23
commit 5a8c899789

@@ -19,37 +19,37 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, bool)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
PrintGpuKernel, bool)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
PrintGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
PrintGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
PrintGpuKernel, uint64_t)
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(Print,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Print, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
PrintGpuKernel, float)
} // namespace kernel
} // namespace mindspore

@@ -17,11 +17,17 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DEBUG_PRINT_GPU_KERNEL_H_
#include <utility>
#include <string>
#include <algorithm>
#include <vector>
#include <memory>
#include "ir/tensor.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
using mindspore::tensor::Tensor;
namespace mindspore {
namespace kernel {
template <typename T>
@@ -37,19 +43,42 @@ class PrintGpuKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(outputs);
for (size_t i = 0; i < inputs.size(); i++) {
input_device_data_[i] = GetDeviceAddress<T>(inputs, i);
}
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,
cudaMemcpy(&input_host_data_[0], &input_device_data_[0], input_size_ * sizeof(T), cudaMemcpyDeviceToHost),
"cudaMemcpy output failed");
for (size_t i = 0; i < input_num_.size(); i++) {
for (size_t j = 0; j < input_num_[i]; j++) {
std::cout << input_host_data_[i][j];
}
int *output_address = GetDeviceAddress<int>(outputs, 0);
// host initialization
std::vector<std::unique_ptr<T[]> > input_host_data;
for (size_t i = 0; i < input_size_.size(); i++) {
std::unique_ptr<T[]> value = std::make_unique<T[]>(input_size_[i]);
input_host_data.push_back(std::move(value));
}
// check type
T type_value = static_cast<T>(0.0f);
auto type_id = CheckType(type_value);
if (type_id == kTypeUnknown) {
MS_LOG(EXCEPTION) << "GPU print does not support the input type.";
}
// print core function
for (size_t i = 0; i < input_host_data.size(); i++) {
std::string error_msg = "cudaMemcpy print loop failed at input_device_data[";
error_msg.append(std::to_string(i));
error_msg.append("].");
CHECK_CUDA_RET_WITH_EXCEPT(
kernel_node_,
cudaMemcpy(input_host_data[i].get(), input_device_data_[i], input_size_[i] * sizeof(T), cudaMemcpyDeviceToHost),
error_msg);
ShapeVector shape;
(void)std::transform(input_shape_[i].begin(), input_shape_[i].end(), std::back_inserter(shape),
[](const size_t &value) { return static_cast<int64_t>(value); });
Tensor current_tensor(type_id, shape, input_host_data[i].get(), input_size_[i] * sizeof(T));
std::cout << current_tensor.ToString() << std::endl;
}
int output = 1;
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_address, &output, sizeof(int), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync output failed");
return true;
}
@@ -57,38 +86,70 @@ class PrintGpuKernel : public GpuKernel {
kernel_node_ = kernel_node;
size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel_node);
input_device_data_ = std::make_unique<T *[]>(input_tensor_num);
input_host_data_ = std::make_unique<T *[]>(input_tensor_num);
std::vector<size_t> value_shape;
for (size_t i = 0; i < input_tensor_num; i++) {
size_t counter = 0;
size_t value = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size_ *= input_shape[j];
counter++;
value *= input_shape[j];
value_shape.push_back(input_shape[j]);
}
input_num_.push_back(counter);
input_size_.push_back(value);
input_shape_.push_back(value_shape);
value_shape.clear();
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_size_ = 1;
input_device_data_ = nullptr;
input_host_data_ = nullptr;
input_num_.clear();
input_size_.clear();
input_shape_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override { input_size_list_.push_back(input_size_ * sizeof(T)); }
void InitSizeLists() override {
for (size_t i = 0; i < input_size_.size(); i++) {
input_size_list_.push_back(input_size_[i] * sizeof(T));
}
output_size_list_.push_back(sizeof(int));
}
TypeId CheckType(T value) {
if (std::is_same<T, bool>::value) {
return kNumberTypeBool;
} else if (std::is_same<T, int8_t>::value) {
return kNumberTypeInt8;
} else if (std::is_same<T, int16_t>::value) {
return kNumberTypeInt16;
} else if (std::is_same<T, int>::value) {
return kNumberTypeInt32;
} else if (std::is_same<T, int64_t>::value) {
return kNumberTypeInt64;
} else if (std::is_same<T, uint8_t>::value) {
return kNumberTypeUInt8;
} else if (std::is_same<T, uint16_t>::value) {
return kNumberTypeUInt16;
} else if (std::is_same<T, uint32_t>::value) {
return kNumberTypeUInt32;
} else if (std::is_same<T, uint64_t>::value) {
return kNumberTypeUInt64;
} else if (std::is_same<T, half>::value) {
return kNumberTypeFloat16;
} else if (std::is_same<T, float>::value) {
return kNumberTypeFloat32;
}
return kTypeUnknown;
}
private:
size_t input_size_;
std::unique_ptr<T *[]> input_device_data_;
std::unique_ptr<T *[]> input_host_data_;
std::vector<size_t> input_num_;
std::vector<size_t> input_size_;
std::vector<std::vector<size_t> > input_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
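
For orientation, here is a rough host-side sketch in Python of what the reworked Launch above does for each input, assuming the device buffers have already been copied to host (the function and variable names below are illustrative, not part of the kernel): reshape the flat buffer to the recorded input shape, wrap it in a Tensor, print its string form, and finally write a dummy int32 to the kernel's single output.

# Sketch only: mirrors the print core loop in Launch in NumPy/MindSpore terms.
# host_buffers stands in for input_host_data after cudaMemcpy; all names are illustrative.
import numpy as np
from mindspore import Tensor

def print_kernel_host_side(host_buffers, shapes):
    for flat, shape in zip(host_buffers, shapes):
        tensor = Tensor(flat.reshape(shape))  # analogous to building current_tensor from input_host_data[i]
        print(tensor)                         # analogous to std::cout << current_tensor.ToString()
    return np.int32(1)                        # the real kernel copies this dummy value to output_address

# Example with two same-type inputs of different shapes, as in the test case below.
print_kernel_host_side([np.arange(6, dtype=np.float32), np.arange(4, dtype=np.float32)],
                       [(2, 3), (2, 2)])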

@@ -341,10 +341,11 @@ class Print(PrimitiveWithInfer):
In pynative mode, please use python print function.
In graph mode, the bool, int, float, tuple, and list would be converted into Tensor to print,
str remains unchanged.
On GPU, all input elements must be of the same type, and string is not supported.
Inputs:
- **input_x** (Union[Tensor, bool, int, float, str, tuple, list]) - The graph node to attach to.
Supports multiple inputs which are separated by ','.
Supports multiple inputs which are separated by ','. GPU does not support string as an input.
Supported Platforms:
``Ascend`` ``GPU``
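
A minimal usage sketch consistent with the docstring above and with the new GPU test below (graph mode on GPU; the network name is illustrative):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class PrintNet(nn.Cell):
    def __init__(self):
        super(PrintNet, self).__init__()
        self.op = P.Print()

    def construct(self, x):
        self.op(x)  # the tensor contents are emitted by the GPU Print kernel
        return x

output = PrintNet()(Tensor(np.arange(9).reshape(3, 3).astype(np.float32)))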

@@ -0,0 +1,135 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops import operations as P
import mindspore.context as context
class PrintNetOneInput(nn.Cell):
def __init__(self):
super(PrintNetOneInput, self).__init__()
self.op = P.Print()
def construct(self, x):
self.op(x)
return x
class PrintNetTwoInputs(nn.Cell):
def __init__(self):
super(PrintNetTwoInputs, self).__init__()
self.op = P.Print()
def construct(self, x, y):
self.op(x, y)
return x
def print_testcase(nptype):
# large shape
x = np.arange(20808).reshape(6, 3, 34, 34).astype(nptype)
# small shape
y = np.arange(9).reshape(3, 3).astype(nptype)
x = Tensor(x)
y = Tensor(y)
# graph mode
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net_1 = PrintNetOneInput()
net_2 = PrintNetTwoInputs()
net_1(x)
net_2(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_bool():
print_testcase(np.bool)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_int8():
print_testcase(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_int16():
print_testcase(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_int32():
print_testcase(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_int64():
print_testcase(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_uint8():
print_testcase(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_uint16():
print_testcase(np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_uint32():
print_testcase(np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_uint64():
print_testcase(np.uint64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_float16():
print_testcase(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_print_float32():
print_testcase(np.float32)
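
The eleven tests above differ only in the NumPy dtype passed to print_testcase. An equivalent, more compact formulation (a sketch, not part of this commit) would parametrize a single test over the dtypes:

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@pytest.mark.parametrize("dtype", [np.bool, np.int8, np.int16, np.int32, np.int64,
                                   np.uint8, np.uint16, np.uint32, np.uint64,
                                   np.float16, np.float32])
def test_print_all_dtypes(dtype):
    print_testcase(dtype)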