MKLDNN implementation of batch normalization (#9904)

* Initial implementation of forward pass for MKLDNN batch norm

* Added attributes for MKLDNN batch norm

* MKLDNN batch norm forward pass passes unittest. Started working on backward

* Backward pass for MKLDNN batch norm added

* MKLDNN batch norm: scoring added to forward pass

* MKLDNN batch norm: bias as input added; handling AnyLayout when kernel is looked up

* MKLDNN batch norm: python unit tests added; mkldnn tests removed

* MKLDNN batch norm: changes required by cpplint

* MKLDNN batch norm: refactoring the operator

* MKLDNN batch norm: saved variance inverted in backward pass for correct execution of MKLDNN unit tests

* MKLDNN batch norm: refactoring; function for static/const cast to void* added

* MKLDNN batch norm: remove AnyLayout from batch norm

*  MKLDNN batch norm: only NCHW format is supported. Unittests refactored

* MKLDNN batch norm: use_mkldnn added to attributes

* MKLDNN batch norm: AnyLayout removed from unittest

* MKLDNN batch norm: added CUDNN defines to batch norm

* MKLDNN batch norm: undefined data_format variable corrected

* MKLDNN batch norm: use_cudnn added, use of setUp method for configuring attributes

* MKLDNN batch norm: added use_cudnn attribute to batch norm operator

* MKLDNN batch norm: correcting batch norm unit tests for MKLDNN

* MKLDNN batch norm: MKLDNN tests moved to another file; reverting changes for saved variance not being inverted

* Change default layout to NCHW

* MKLDNN batch norm: init_kernel_type method added to unit tests

* MKLDNN batch norm: style changes

* MKLDNN batch norm: unit tests refactored

* MKLDNN batch norm: added use_mkldnn attribute to batch norm python interface
Branch: trainerSaveLoadParams
Authored by Tomasz Patejko 7 years ago; committed by Tao Luo
parent 4fbde42cdf
commit 4a497b826d
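The diffs below lean on two reference helpers, _reference_training and _reference_grad, which live in test_batch_norm_op.py and are not part of this change. As orientation only, here is a minimal NumPy sketch of the forward computation such a helper performs for the NCHW layout (function and variable names are illustrative, not the actual helper code):

import numpy as np

def reference_batch_norm_forward_nchw(x, scale, bias, epsilon):
    # x: [N, C, H, W]; statistics are computed per channel over N, H, W.
    mean = np.mean(x, axis=(0, 2, 3))
    var = np.var(x, axis=(0, 2, 3))
    x_hat = (x - mean.reshape(1, -1, 1, 1)) / np.sqrt(var.reshape(1, -1, 1, 1) + epsilon)
    y = scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)
    return y, mean, var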

[File diff suppressed because it is too large]

@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
@@ -106,7 +109,18 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_);
}
};
@@ -151,6 +165,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
@@ -349,8 +366,19 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library_);
}
};
@@ -474,6 +502,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Scale", Input("Scale"));
op->SetInput("Bias", Input("Bias"));
op->SetInput("SavedMean", Output("SavedMean"));
op->SetInput("SavedVariance", Output("SavedVariance"));

@@ -1496,6 +1496,7 @@ def batch_norm(input,
bias_attr=None,
data_layout='NCHW',
in_place=False,
use_mkldnn=False,
name=None,
moving_mean_name=None,
moving_variance_name=None,
@@ -1574,9 +1575,12 @@ def batch_norm(input,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
attrs={
"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
"is_test": is_test,
"use_mkldnn": use_mkldnn
})
return helper.append_activation(batch_norm_out)
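With the new attribute exposed in the Python API, selecting the MKLDNN kernel only requires passing use_mkldnn=True to fluid.layers.batch_norm. A hypothetical usage sketch (layer names and shapes are illustrative; the MKLDNN kernel is only picked when PaddlePaddle is built with PADDLE_WITH_MKLDNN and the op runs on CPU):

import paddle.fluid as fluid

data = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
conv = fluid.layers.conv2d(input=data, num_filters=16, filter_size=3)
# The kernel lookup in BatchNormOp::GetExpectedKernelType above switches to
# LibraryType::kMKLDNN when use_mkldnn is set and MKLDNN can be used.
out = fluid.layers.batch_norm(input=conv, use_mkldnn=True)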

@@ -0,0 +1,56 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from op_test import OpTest
from paddle.fluid.framework import grad_var_name
from test_batch_norm_op import TestBatchNormOpInference, TestBatchNormOpTraining, _reference_training, _reference_grad
class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_formats = ["NCHW"]
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
# run forward
y, saved_mean, saved_variance = _reference_training(
x, scale, bias, epsilon, data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = saved_variance * (1. - momentum) + momentum * variance
# run backward
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, saved_variance, epsilon, data_layout)
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
def init_kernel_type(self):
self.use_mkldnn = True
def test_check_output(self):
place = core.CPUPlace()
data_format = "NCHW"
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
if __name__ == '__main__':
unittest.main()
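The override of ref_forward_backward above differs from the base-class version in test_batch_norm_op.py (further below) only in the SavedVariance convention: the plain CPU/CUDA reference saves the inverted standard deviation, while the MKLDNN reference keeps the raw mini-batch variance, consistent with the commit notes about reverting the saved-variance inversion. A tiny illustrative comparison (not code from this diff):

import numpy as np

var_ref = np.array([0.5, 2.0], dtype=np.float32)
epsilon = 1e-5
saved_variance_plain = 1. / np.sqrt(var_ref + epsilon)   # base TestBatchNormOpTraining
saved_variance_mkldnn = var_ref                           # MKLDNN override above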

@@ -158,6 +158,8 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
class TestBatchNormOpInference(unittest.TestCase):
def setUp(self):
self.dtype = np.float32
self.use_mkldnn = False
self.init_kernel_type()
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
@@ -230,6 +232,7 @@ class TestBatchNormOpInference(unittest.TestCase):
# attrs
is_test=True,
data_layout=data_layout,
use_mkldnn=self.use_mkldnn,
epsilon=epsilon)
batch_norm_op.run(scope, place)
@@ -254,10 +257,15 @@ class TestBatchNormOpInference(unittest.TestCase):
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
def init_kernel_type(self):
pass
class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.float16
self.use_mkldnn = False
self.init_kernel_type()
def test_check_output(self):
places = []
@@ -274,9 +282,28 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class TestBatchNormOpTraining(unittest.TestCase):
def setUp(self):
self.use_mkldnn = False
self.data_formats = ["NCHW", "NHWC"]
self.init_kernel_type()
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
np.allclose(np.array(tensor), np_array, atol=atol)
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
# run forward
y, saved_mean, var_ref = _reference_training(x, scale, bias, epsilon,
data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# run backward
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
def test_forward_backward(self):
def test_with_place(place, data_layout, shape):
# attr
@@ -295,16 +322,11 @@ class TestBatchNormOpTraining(unittest.TestCase):
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# run forward
y, saved_mean, var_ref = _reference_training(x, scale, bias,
epsilon, data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# run backward
y_grad = np.random.random_sample(shape).astype(np.float32)
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
x, y_grad, scale, bias, mean, variance, epsilon, momentum,
shape, data_layout)
var_dict = locals()
var_dict['y@GRAD'] = y_grad
@@ -344,7 +366,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
"momentum": momentum,
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn
})
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
@@ -387,13 +410,17 @@ class TestBatchNormOpTraining(unittest.TestCase):
print "op test forward passed: ", str(place), data_layout
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
for place in places:
for data_format in ["NCHW", "NHWC"]:
for data_format in self.data_formats:
test_with_place(place, data_format, [2, 3, 4, 5])
def init_kernel_type(self):
pass
if __name__ == '__main__':
unittest.main()
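For completeness, a minimal NumPy sketch of the backward formulas a helper like _reference_grad computes for the NCHW layout (illustrative names; the real helper is defined in test_batch_norm_op.py and is not shown in this diff):

import numpy as np

def reference_batch_norm_grad_nchw(x, y_grad, scale, mean, var, epsilon):
    # Per-channel reductions over N, H, W for an [N, C, H, W] input.
    axes = (0, 2, 3)
    mean_b = mean.reshape(1, -1, 1, 1)
    var_b = var.reshape(1, -1, 1, 1)
    inv_std = 1. / np.sqrt(var_b + epsilon)
    x_grad = scale.reshape(1, -1, 1, 1) * inv_std * (
        y_grad - np.mean(y_grad, axis=axes, keepdims=True)
        - (x - mean_b) / (var_b + epsilon)
        * np.mean(y_grad * (x - mean_b), axis=axes, keepdims=True))
    scale_grad = np.sum(y_grad * (x - mean_b) * inv_std, axis=axes)
    bias_grad = np.sum(y_grad, axis=axes)
    return x_grad, scale_grad, bias_grad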
