[oneDNN] Layer norm bf16 kernel (#28619)
parent cdc4e6620d
commit 6d8d3d4c22
@@ -0,0 +1,177 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

template <typename T>
class LayerNormMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
 public:
  LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
                         const dnnl::normalization_flags& flags,
                         const bool& is_test, const MKLDNNMemoryFormat fmt,
                         const platform::MKLDNNDeviceContext& dev_ctx,
                         platform::Place cpu_place,
                         const std::string& uniq_name)
      : platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
            dev_ctx, dev_ctx.GetEngine(), cpu_place,
            platform::CreateKey(dims, uniq_name)) {
    if (!this->isCached()) {
      auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
      if (!is_test) {
        // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
        auto stats_md = dnnl::memory::desc(
            {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
            platform::MKLDNNFormatForSize(dims.size() - 1,
                                          MKLDNNMemoryFormat::nchw));
        this->AcquireForwardPrimitiveDescriptor(
            dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
      } else {
        this->AcquireForwardPrimitiveDescriptor(
            dnnl::prop_kind::forward_inference, md, epsilon, flags);
      }
    }
  }

  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
    return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
  }

  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
      std::vector<float>& scaleshift_data) {
    // scaleshift_data comes from a temporary buffer, so we need to copy it
    // into the created memory primitive
    auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
        this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
    auto data_ptr = scaleshift_mem->get_data_handle();
    std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
    std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
    return scaleshift_mem;
  }

  std::shared_ptr<dnnl::memory> AcquireMeanMemory(framework::Tensor* mean) {
    T* mean_data = mean->mutable_data<T>(this->place_,
                                         this->fwd_pd_->mean_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
                                            mean_data, "@mean_mem_p");
  }

  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
      framework::Tensor* variance) {
    T* variance_data = variance->mutable_data<T>(
        this->place_, this->fwd_pd_->variance_desc().get_size());
    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
                                            variance_data, "@variance_mem_p");
  }
};

template <typename T>
class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* scale = ctx.Input<Tensor>("Scale");
    auto* bias = ctx.Input<Tensor>("Bias");
    auto* y = ctx.Output<Tensor>("Y");

    const float epsilon = ctx.Attr<float>("epsilon");
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const bool is_test = ctx.Attr<bool>("is_test");

    auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();

    auto src_tz = paddle::framework::vectorize(x->dims());
    PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),
                      platform::errors::InvalidArgument(
                          "MKL-DNN Layer Norm supports only last logical "
                          "axis:%d as begin_norm_axis.",
                          (src_tz.size() - 1)));

    y->mutable_data<T>(ctx.GetPlace());
    const bool with_scaleshift = (scale && bias);
    dnnl::normalization_flags flags{};

    if (with_scaleshift) {
      flags |= dnnl::normalization_flags::use_scale_shift;
    }

    LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
                                      x->format(), dev_ctx, ctx.GetPlace(),
                                      ctx.OutputName("Y"));

    auto src_memory = handler.AcquireSrcMemory(x);
    auto dst_memory = handler.AcquireDstMemory(y);

    auto layer_norm_p = handler.AcquireForwardPrimitive();

    dnnl::stream astream(dev_ctx.GetEngine());
    std::unordered_map<int, dnnl::memory> args;

    args.insert({DNNL_ARG_SRC, *src_memory});
    args.insert({DNNL_ARG_DST, *dst_memory});

    if (!is_test) {
      auto* mean = ctx.Output<Tensor>("Mean");
      auto* var = ctx.Output<Tensor>("Variance");
      mean->mutable_data<T>(ctx.GetPlace());
      var->mutable_data<T>(ctx.GetPlace());

      auto mean_memory = handler.AcquireMeanMemory(mean);
      auto variance_memory = handler.AcquireVarianceMemory(var);

      args.insert({DNNL_ARG_MEAN, *mean_memory});
      args.insert({DNNL_ARG_VARIANCE, *variance_memory});
    }

    auto scaleshift_memory = handler.AcquireScaleShiftMemory();
    if (with_scaleshift) {
      if (scaleshift_memory == nullptr || !is_test) {
        auto scale_tz = paddle::framework::vectorize(scale->dims());
        const unsigned int C = scale_tz[0];

        // MKLDNN requires a single piece of memory for scale and shift/bias
        // data
        std::vector<float> scaleshift_data;
        scaleshift_data.reserve(2 * C);
        scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
                               scale->data<float>() + C);

        scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
                               bias->data<float>() + C);

        scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
      }
      args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
    }

    layer_norm_p->execute(astream, args);
    astream.wait();

    y->set_layout(DataLayout::kMKLDNN);
    y->set_format(platform::GetMKLDNNFormat(*dst_memory));
  }
};

}  // namespace operators
}  // namespace paddle

// TODO(jczaja): Enable FP32 when performance is good
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(layer_norm, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::LayerNormMKLDNNOpKernel<paddle::platform::bfloat16>);
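oneDNN's use_scale_shift flag expects gamma and beta packed into a single contiguous weights buffer of 2*C floats, with the C scale values first and the C shift (bias) values after them; that is the layout the scaleshift_data vector in the kernel above builds before copying it into the acquired memory. A minimal NumPy sketch of the same layout, with illustrative names and sizes that are not part of the patch:

import numpy as np

C = 4  # number of elements along the normalized (last) axis
scale = np.random.random_sample(C).astype(np.float32)  # gamma
bias = np.random.random_sample(C).astype(np.float32)   # beta

# One buffer of 2*C floats: scale values followed by shift values,
# mirroring what LayerNormMKLDNNOpKernel copies into the weights memory.
scaleshift = np.concatenate([scale, bias]).astype(np.float32)
assert scaleshift.shape == (2 * C,)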
@@ -0,0 +1,146 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from paddle.fluid.tests.unittests.test_layer_norm_op import *
from __future__ import print_function
import unittest
import numpy as np

from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle import enable_static
from functools import reduce

from paddle.fluid.tests.unittests.mkldnn.test_layer_norm_mkldnn_op import TestLayerNormMKLDNNOp
from paddle.fluid.tests.unittests.mkldnn.test_layer_norm_mkldnn_op import _reference_layer_norm_naive
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator

np.random.seed(123)

_set_use_system_allocator(True)


@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestLayerNormBF16MKLDNNOp(TestLayerNormMKLDNNOp):
    def __assert_close(self, tensor, np_array, msg, rtol=2e-02, atol=2):
        self.assertTrue(
            np.allclose(
                np.array(tensor), np_array, rtol=rtol, atol=atol), msg)

    def check_forward(self,
                      shape,
                      begin_norm_axis,
                      with_scale_bias=True,
                      with_is_test=False):
        # attr
        epsilon = 0.00001
        x_shape = shape
        D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
        scale_shape = [D]

        np.random.seed(123)
        x = np.random.random_sample(x_shape).astype(np.float32)
        x_bf16 = convert_float_to_uint16(x)

        if with_scale_bias:
            scale = np.random.random_sample(scale_shape).astype(np.float32)
            bias = np.random.random_sample(scale_shape).astype(np.float32)
        else:
            scale = np.array([])
            bias = np.array([])

        # reference forward
        y, mean, variance = _reference_layer_norm_naive(x, scale, bias, epsilon,
                                                        begin_norm_axis)

        y_bf16 = convert_float_to_uint16(y)

        var_dict = locals()
        var_names = ['x_bf16', 'mean', 'variance', 'y_bf16']
        if with_scale_bias:
            var_names.append('scale')
            var_names.append('bias')
        ground_truth = {name: var_dict[name] for name in var_names}

        program = fluid.Program()
        with fluid.program_guard(program):
            block = program.global_block()

            # scale and bias are fp32; all other vars are bf16
            for name in ground_truth:
                if name == 'x_bf16' or name == 'y_bf16':
                    block.create_var(
                        name=name,
                        dtype='uint16',
                        shape=ground_truth[name].shape)
                else:
                    block.create_var(
                        name=name,
                        dtype='float32',
                        shape=ground_truth[name].shape)

            inputs = {"X": block.var('x_bf16')}
            if with_scale_bias:
                inputs["Scale"] = block.var('scale')
                inputs["Bias"] = block.var('bias')

            block.append_op(
                type="layer_norm",
                inputs=inputs,
                outputs={
                    "Y": block.var('y_bf16'),
                    "Mean": block.var('mean'),  # share the same memory
                    "Variance": block.var('variance'),  # share the same memory
                },
                attrs={
                    "epsilon": epsilon,
                    "begin_norm_axis": begin_norm_axis,
                    "use_mkldnn": True,
                    "is_test": with_is_test
                })

            exe = fluid.Executor(core.CPUPlace())

            input_list = ['x_bf16']
            if with_scale_bias:
                input_list.append('scale')
                input_list.append('bias')

            out = exe.run(program,
                          feed={name: var_dict[name]
                                for name in input_list},
                          fetch_list=['y_bf16', 'mean', 'variance'])
            self.__assert_close(y_bf16, out[0], "y_bf16", 2)
            if not with_is_test:
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)

    def test_check_forward_with_is_test(self):
        self.check_forward(
            shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True)

    # TODO(jczaja): Enable these tests once training with bf16 is supported
    def test_check_forward_with_scale_and_bias(self):
        pass

    def test_check_forward_without_scale_and_bias(self):
        pass


if __name__ == "__main__":
    enable_static()
    unittest.main()
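The bf16 test feeds X and fetches Y as uint16 tensors carrying bfloat16 bit patterns, produced by convert_float_to_uint16 from op_test. The sketch below only illustrates the mapping that helper is assumed to perform (fp32 value truncated to its upper 16 bits); it is not the helper's actual implementation:

import numpy as np

def fp32_to_bf16_bits(a):
    # Keep the upper 16 bits of each fp32 value (truncation, no rounding);
    # the result is the uint16 bit pattern of the corresponding bfloat16.
    a = np.ascontiguousarray(a, dtype=np.float32)
    return (a.view(np.uint32) >> 16).astype(np.uint16)

print(fp32_to_bf16_bits(np.array([1.0, 0.5, 3.140625], dtype=np.float32)))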
@@ -0,0 +1,151 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from paddle.fluid.tests.unittests.test_layer_norm_op import *
from __future__ import print_function
import unittest
import numpy as np

from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle import enable_static
from functools import reduce

from paddle.fluid.tests.unittests.op_test import _set_use_system_allocator

np.random.seed(123)

_set_use_system_allocator(True)


def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
    x_shape = x.shape
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
    D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
    x.shape = [N, D]
    if scale.size == 0 and beta.size == 0:
        scale = np.ones([1, D])
        beta = np.zeros([1, D])
    else:
        scale = scale.reshape([1, D])
        beta = beta.reshape([1, D])

    mean = np.mean(x, axis=1)
    var = np.var(x, axis=1) + epsilon
    output = scale * np.divide((x - mean.reshape([N, 1])),
                               (np.sqrt(var)).reshape([N, 1])) + beta

    x.shape, output.shape = x_shape, x_shape
    return output, mean, var


class TestLayerNormMKLDNNOp(unittest.TestCase):
    def setUp(self):
        self.use_mkldnn = True

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_forward(self,
                      shape,
                      begin_norm_axis,
                      with_scale_bias=True,
                      with_is_test=False):
        # attr
        epsilon = 0.00001
        x_shape = shape
        D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
        scale_shape = [D]

        np.random.seed(123)
        x = np.random.random_sample(x_shape).astype(np.float32)

        if with_scale_bias:
            scale = np.random.random_sample(scale_shape).astype(np.float32)
            bias = np.random.random_sample(scale_shape).astype(np.float32)
        else:
            scale = np.array([])
            bias = np.array([])

        # reference forward
        y, mean, variance = _reference_layer_norm_naive(x, scale, bias, epsilon,
                                                        begin_norm_axis)

        var_dict = locals()
        var_names = ['x', 'mean', 'variance', 'y']
        if with_scale_bias:
            var_names.append('scale')
            var_names.append('bias')
        ground_truth = {name: var_dict[name] for name in var_names}

        program = fluid.Program()
        with fluid.program_guard(program):
            block = program.global_block()

            for name in ground_truth:
                block.create_var(
                    name=name, dtype='float32', shape=ground_truth[name].shape)

            inputs = {"X": block.var('x')}
            if with_scale_bias:
                inputs["Scale"] = block.var('scale')
                inputs["Bias"] = block.var('bias')

            block.append_op(
                type="layer_norm",
                inputs=inputs,
                outputs={
                    "Y": block.var('y'),
                    "Mean": block.var('mean'),  # share the same memory
                    "Variance": block.var('variance'),  # share the same memory
                },
                attrs={
                    "epsilon": epsilon,
                    "begin_norm_axis": begin_norm_axis,
                    "use_mkldnn": True,
                    "is_test": with_is_test
                })

            exe = fluid.Executor(core.CPUPlace())

            input_list = ['x']
            if with_scale_bias:
                input_list.append('scale')
                input_list.append('bias')

            out = exe.run(program,
                          feed={name: var_dict[name]
                                for name in input_list},
                          fetch_list=['y', 'mean', 'variance'])
            self.__assert_close(y, out[0], "y")
            if not with_is_test:
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)

    def test_check_forward_with_scale_and_bias(self):
        self.check_forward(shape=[2, 3, 4, 5], begin_norm_axis=3)

    def test_check_forward_without_scale_and_bias(self):
        self.check_forward(
            shape=[2, 3, 4, 5], begin_norm_axis=3, with_scale_bias=False)

    def test_check_forward_with_is_test(self):
        self.check_forward(
            shape=[2, 3, 4, 5], begin_norm_axis=3, with_is_test=True)


if __name__ == "__main__":
    enable_static()
    unittest.main()
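For context, _reference_layer_norm_naive above flattens the leading axes into N rows of D elements and normalizes each row, folding epsilon into the variance it returns. A standalone restatement of the same math on a toy input, with illustrative shapes that are not part of the patch:

import numpy as np

epsilon = 1e-5
x = np.random.random_sample((2, 3, 4)).astype(np.float32)  # begin_norm_axis = 2
N, D = 2 * 3, 4

xr = x.reshape(N, D)
mean = xr.mean(axis=1, keepdims=True)
var = xr.var(axis=1, keepdims=True) + epsilon  # epsilon folded into the returned variance
y = (xr - mean) / np.sqrt(var)  # multiplied by scale and shifted by beta when provided
y = y.reshape(x.shape)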