Add support for fp16 in fusion_group (#22239)
parent d97475d53b
commit 22bbd54719
@@ -0,0 +1,82 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {

static constexpr char predefined_cuda_functions_fp32[] = R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }

)";

static constexpr char predefined_cuda_functions_fp64[] = R"(
__device__ inline double real_exp(double x) { return ::exp(x); }
__device__ inline double real_log(double x) { return ::log(x); }

)";

static constexpr char predefined_cuda_functions_fp16[] = R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }

#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))

struct __align__(2) __half {
  __device__ __half() { }

 protected:
  unsigned short __x;
};

__device__ __half __float2half(const float f) {
  __half val;
  asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val)) : "f"(f));
  return val;
}

__device__ float __half2float(const __half h) {
  float val;
  asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
  return val;
}

#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half float16;

)";

static constexpr char cuda_kernel_template_1d[] = R"(
extern "C" __global__ void $func_name($parameters) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    $compute_body
  }
}
)";

}  // namespace fusion_group
}  // namespace ir
}  // namespace framework
}  // namespace paddle
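Note: the strings above are building blocks for generated CUDA source. The dtype-specific predefined functions are placed in front of an instantiation of cuda_kernel_template_1d, whose $func_name, $parameters, and $compute_body placeholders are filled in by the code generator before the source is compiled at runtime. Below is a minimal sketch of what a generated fp16 kernel could look like; the kernel name, parameter list, and fused body (add + relu) are illustrative assumptions, not output of the actual generator, and it only compiles together with the fp16 helpers defined above.

// Hypothetical expansion of cuda_kernel_template_1d for a fused fp16
// "add + relu" subgraph; names and body are assumptions for illustration.
extern "C" __global__ void fused_elementwise_fp16_0(int N, float16* x,
                                                    float16* y, float16* out) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
    // float16 is the __half alias from predefined_cuda_functions_fp16;
    // arithmetic is done in fp32 via the inline-PTX conversion helpers.
    float tmp = __half2float(x[idx]) + __half2float(y[idx]);
    out[idx] = __float2half(tmp > 0.0f ? tmp : 0.0f);
  }
}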
@@ -1,109 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest
import paddle.fluid as fluid
import paddle.fluid.core as core


class FusionGroupPassTest(PassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data1 = fluid.data(name="data1", shape=[32, 128], dtype="float32")
            data2 = fluid.data(name="data2", shape=[32, 128], dtype="float32")
            data3 = fluid.data(name="data3", shape=[32, 128], dtype="float32")
            tmp_1 = fluid.layers.elementwise_add(data1, data2)
            tmp_2 = fluid.layers.elementwise_mul(data3, tmp_1)

        self.feeds = {
            "data1": np.random.random((32, 128)).astype("float32"),
            "data2": np.random.random((32, 128)).astype("float32"),
            "data3": np.random.random((32, 128)).astype("float32")
        }
        self.fetch_list = [tmp_1, tmp_2]
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"
        self.num_fused_ops = 1

    def test_check_output(self):
        use_gpu_set = []
        if core.is_compiled_with_cuda():
            use_gpu_set.append(True)
        for use_gpu in use_gpu_set:
            self.pass_attrs = {"fusion_group_pass": {"use_gpu": use_gpu}}
            place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
            self.check_output_with_place(place, startup_on_cpu=False)


class FusionGroupPassTest1(FusionGroupPassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data = []
            for i in range(5):
                data.append(
                    fluid.data(
                        name=("data" + str(i)),
                        shape=[32, 128],
                        dtype="float32"))
            tmp_1 = (
                fluid.layers.assign(data[0]) * fluid.layers.sigmoid(data[1])
            ) + (fluid.layers.sigmoid(data[2]) * fluid.layers.tanh(data[3]))
            tmp_2 = fluid.layers.tanh(tmp_1) + fluid.layers.sigmoid(data[4])

        self.feeds = {}
        for i in range(5):
            self.feeds["data" + str(i)] = np.random.random(
                (32, 128)).astype("float32")

        self.fetch_list = [tmp_1, tmp_2]
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"
        self.num_fused_ops = 1


class FusionGroupPassTest2(FusionGroupPassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data = []
            for i in range(3):
                data.append(
                    fluid.data(
                        name=("data" + str(i)),
                        shape=[32, 128],
                        dtype="float32"))
            data.append(
                fluid.data(
                    name="data3", shape=[128, 32], dtype="float32"))
            tmp_1 = fluid.layers.relu((data[0] - data[1]) * data[2])
            tmp_2 = fluid.layers.sigmoid(data[3])
            tmp_3 = fluid.layers.relu(tmp_2)
            tmp_4 = fluid.layers.mul(tmp_1, tmp_3)

        self.feeds = {}
        for i in range(3):
            self.feeds["data" + str(i)] = np.random.random(
                (32, 128)).astype("float32")
        self.feeds["data3"] = np.random.random((128, 32)).astype("float32")

        self.fetch_list = [tmp_1, tmp_2, tmp_3, tmp_4]
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"
        self.num_fused_ops = 2


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,142 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from pass_test import PassTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core


class FusionGroupPassTest(PassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2)
            self.feed_vars.append(
                fluid.data(
                    name="data2", shape=[128, 128], dtype=dtype))

            # subgraph with only 1 op node
            tmp_0 = self.feed_vars[0] * self.feed_vars[1]
            tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
            # subgraph with 2 op nodes
            tmp_2 = layers.relu(tmp_0 + tmp_1)

        self.fetch_list = [tmp_2]
        self.num_fused_ops = 1

    def setUp(self):
        self.build_program("float32")
        self.feeds = self._feed_random_data(self.feed_vars)
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"

    def _prepare_feed_vars(self, shape, dtype, num_data):
        feed_vars = []
        for i in range(num_data):
            var = fluid.data(name=("data" + str(i)), shape=shape, dtype=dtype)
            feed_vars.append(var)
        return feed_vars

    def _feed_random_data(self, feed_vars):
        feeds = {}
        for var in feed_vars:
            if var.type != fluid.core.VarDesc.VarType.LOD_TENSOR:
                raise TypeError("Feed data of non LoDTensor is not supported.")

            shape = var.shape
            if var.dtype == fluid.core.VarDesc.VarType.FP32:
                dtype = "float32"
            elif var.dtype == fluid.core.VarDesc.VarType.FP64:
                dtype = "float64"
            elif var.dtype == fluid.core.VarDesc.VarType.FP16:
                dtype = "float16"
            else:
                raise ValueError("Unsupported dtype %s" % var.dtype)
            feeds[var.name] = np.random.random(shape).astype(dtype)
        return feeds

    def test_check_output(self):
        if core.is_compiled_with_cuda():
            self.pass_attrs = {"fusion_group_pass": {"use_gpu": True}}
            self.check_output_with_place(fluid.CUDAPlace(0))


class FusionGroupPassTest1(FusionGroupPassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)

            tmp_0 = layers.assign(self.feed_vars[0])
            # subgraph with 9 op nodes
            tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid(
                self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
            tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4])

        self.fetch_list = [tmp_1, tmp_2]
        self.num_fused_ops = 1


class FusionGroupPassTest2(FusionGroupPassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
            self.feed_vars.append(
                fluid.data(
                    name="data3", shape=[128, 32], dtype=dtype))

            # subgraph with 3 op nodes
            tmp_1 = layers.relu(
                (self.feed_vars[0] - self.feed_vars[1]) * self.feed_vars[2])
            # subgraph with 2 op nodes
            tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
            tmp_3 = layers.mul(tmp_1, tmp_2)

        self.fetch_list = [tmp_1, tmp_2, tmp_3]
        self.num_fused_ops = 2


class FusionGroupPassTestFP64(FusionGroupPassTest):
    def setUp(self):
        self.build_program("float64")
        self.feeds = self._feed_random_data(self.feed_vars)
        self.pass_names = "fusion_group_pass"
        self.fused_op_type = "fusion_group"


class FusionGroupPassTestFP16(FusionGroupPassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2)
            self.feed_vars.append(
                fluid.data(
                    name="data2", shape=[128, 128], dtype=dtype))

            # subgraph with only 1 op node
            tmp_0 = self.feed_vars[0] * self.feed_vars[1]
            tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
            tmp_2 = layers.cast(tmp_0, dtype="float16")
            tmp_3 = layers.cast(tmp_1, dtype="float16")
            # subgraph with 2 op nodes
            tmp_4 = layers.relu(tmp_2 + tmp_3)
            tmp_5 = layers.cast(tmp_4, dtype=dtype)

        self.fetch_list = [tmp_5]
        self.num_fused_ops = 1


if __name__ == "__main__":
    unittest.main()