MKLDNN implementation of batch normalization (#9904)
* Initial implementation of forward pass for MKLDNN batch norm * Added attributes for MKLDNN batch norm * MKLDNN batch norm forward pass passes unittest. Started working on backward * Backward pass for MKLDNN batch norm added * MKLDNN batch norm: scoring added to forward pass * MKLDNN batch norm: bias as input added; handling AnyLayout when kernel is looked up * MKLDNN batch norm: python unit tests added; mkldnn tests removed * MKLDNN batch norm: changes required by cpplint * MKLDNN batch norm: refactoring the operator * MKLDNN batch norm: saved variance inversed in backward pass for correct execution of MKLDNN unit tests * MKLDNN batch norm: refctoring, function for static/const cast to void* added * MKLDNN batch norm: remove AnyLayout from batch norm * MKLDNN batch norm: only NCHW format is supported. Unittests refactored * MKDNN batch norm: use_mkldnn added to attributes * MKLDNN batch norm: AnyLayout removed from unittest * MKLDNN batch norm: added CUDNN defines to batch norm * MKLDNN batch norm: undefined data_format variable corrected * MKLDNN batch norm: use_cudnn added, use of setUp method for configuring attributes * MKLDNN batch norm: added use_cudnn attribute to batch norm operator * MKLDNN batch norm: correcting batch norm unit tests for MKLDNN * MKLDNN batch norm: MKLDNN tests moved to another file; reverting changes for saved variance not being inverted * Change default layout to NCHW * MKLDNN batch norm: init_kernel_type method added to unit tests * MKLDNN batch norm: style changes * MKLDNN batch norm: unit tests refactored * MKLDNN batch norm: added use_mkldnn attribute to batch norm python interfacetrainerSaveLoadParams
parent
4fbde42cdf
commit
4a497b826d
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,56 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.op import Operator
|
||||
import paddle.fluid as fluid
|
||||
from op_test import OpTest
|
||||
from paddle.fluid.framework import grad_var_name
|
||||
from test_batch_norm_op import TestBatchNormOpInference, TestBatchNormOpTraining, _reference_training, _reference_grad
|
||||
|
||||
|
||||
class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
|
||||
def init_kernel_type(self):
|
||||
self.use_mkldnn = True
|
||||
self.data_formats = ["NCHW"]
|
||||
|
||||
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
|
||||
epsilon, momentum, shape, data_layout):
|
||||
# run forward
|
||||
y, saved_mean, saved_variance = _reference_training(
|
||||
x, scale, bias, epsilon, data_layout)
|
||||
mean_out = saved_mean * (1. - momentum) + momentum * mean
|
||||
variance_out = saved_variance * (1. - momentum) + momentum * variance
|
||||
# run backward
|
||||
x_grad, scale_grad, bias_grad = _reference_grad(
|
||||
x, y_grad, scale, saved_mean, saved_variance, epsilon, data_layout)
|
||||
|
||||
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
|
||||
|
||||
|
||||
class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
|
||||
def init_kernel_type(self):
|
||||
self.use_mkldnn = True
|
||||
|
||||
def test_check_output(self):
|
||||
place = core.CPUPlace()
|
||||
data_format = "NCHW"
|
||||
|
||||
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue