!281 bug fix in GPU aware quantization ops
Merge pull request !281 from SanjayChan/origin/quant_liantiao
commit 0ca4ceb73f
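For context when reading the tests below: in quantization-aware training, batch-norm folding merges BN statistics into the preceding convolution. BatchNormFold2 applies a per-channel correction, rescaling by running_std / batch_std and shifting by the folded bias while current_step < freeze_bn, then switching to running statistics once frozen. A minimal NumPy sketch of the reference math the first test file checks (the helper name batchnorm_fold2_reference is illustrative, not the operator's API):

import numpy as np

def batchnorm_fold2_reference(x, beta, gamma, batch_std, batch_mean,
                              running_std, running_mean, step, freeze_bn):
    # Per-channel vectors broadcast over NCHW input via reshape(-1, 1, 1).
    if step >= freeze_bn:
        # Frozen: only running statistics are applied.
        return x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, 1)
    # Not yet frozen: correct by the ratio of running to batch std.
    return (x * (running_std / batch_std).reshape(-1, 1, 1)
            + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, 1))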
@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context

context.set_context(device_target='GPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.BatchNormFold2(100000)

    @ms_function
    def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step):
        return self.op(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step)


class Net_gnd(nn.Cell):
    def __init__(self):
        super(Net_gnd, self).__init__()
        self.conv_mul = P.ConvMul(freeze_bn=100000)
        self.correct_add = P.CorrectionAdd(freeze_bn=100000)
        self.add_fold = P.AddFold()

    @ms_function
    def construct(self, x, beta, gamma, batch_std, batch_mean, running_std, running_mean, current_step):
        out = self.conv_mul(x, batch_std, running_std, current_step)
        out = self.correct_add(out, gamma, batch_std, batch_mean,
                               running_std, running_mean, current_step)
        out = self.add_fold(out, beta, gamma, batch_std, batch_mean)
        return out


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold2():
    net = Net()
    c = 64
    freeze_bn = 100000
    x = np.random.uniform(-1, 1, size=[3, c, 32, 32]).astype('float32')
    beta = np.random.uniform(1, 2, size=[c]).astype('float32')
    gamma = np.random.uniform(1, 2, size=[c]).astype('float32')
    batch_std = np.random.uniform(1, 2, size=[c]).astype('float32')
    batch_mean = np.random.uniform(1, 2, size=[c]).astype('float32')
    running_std = np.random.uniform(1, 2, size=[c]).astype('float32')
    running_mean = np.random.uniform(1, 2, size=[c]).astype('float32')

    # Step 0 is below freeze_bn, so the batch-statistics branch applies.
    current_step = np.array([0]).astype('int32')
    output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean),
                 Tensor(running_std), Tensor(running_mean), Tensor(current_step))
    expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, 1)
              if current_step >= freeze_bn else
              x * (running_std / batch_std).reshape(-1, 1, 1)
              + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, 1))
    error = np.ones(shape=expect.shape) * 1.0e-6
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(diff > error * -1)

    # At step 100000 the op is frozen and uses running statistics only.
    current_step = np.array([100000]).astype('int32')
    output = net(Tensor(x), Tensor(beta), Tensor(gamma), Tensor(batch_std), Tensor(batch_mean),
                 Tensor(running_std), Tensor(running_mean), Tensor(current_step))
    expect = (x + beta.reshape(-1, 1, 1) - (gamma * running_mean / running_std).reshape(-1, 1, 1)
              if current_step >= freeze_bn else
              x * (running_std / batch_std).reshape(-1, 1, 1)
              + (beta - gamma * batch_mean / batch_std).reshape(-1, 1, 1))
    error = np.ones(shape=expect.shape) * 1.0e-6
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(diff > error * -1)
@@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context

context.set_context(device_target='GPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.BatchNormFoldGrad(freeze_bn=10)

    @ms_function
    def construct(self, d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step):
        dx = self.op(d_batch_mean, d_batch_std, x, batch_mean, batch_std, current_step)
        return dx


def np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std):
    n = x.shape[0] * x.shape[2] * x.shape[3]
    dx = d_batch_mean.reshape(1, -1, 1, 1) / n + d_batch_std.reshape(1, -1, 1, 1) * (
        x - batch_mean.reshape(1, -1, 1, 1)) / batch_std.reshape(1, -1, 1, 1) / n
    return dx


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad1():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
    d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([0]).astype('int32')
    dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
             Tensor(current_step))
    expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std)
    assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad2():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[1, c, 256, 256]).astype('float32')
    d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([0]).astype('int32')
    dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
             Tensor(current_step))
    expect = np_result(d_batch_mean, d_batch_std, x, batch_mean, batch_std)
    assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_grad_freeze():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
    d_batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    d_batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    batch_std = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([10]).astype('int32')
    dx = net(Tensor(d_batch_mean), Tensor(d_batch_std), Tensor(x), Tensor(batch_mean), Tensor(batch_std),
             Tensor(current_step))
    expect = np.zeros_like(x)
    assert np.allclose(dx.asnumpy(), expect, rtol=1.e-7, atol=1.e-7)
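An annotation on where the dx formula in the gradient file's np_result comes from (added here for the reader, not part of the patch): with loss L = sum_c(d_batch_mean_c * batch_mean_c) + sum_c(d_batch_std_c * batch_std_c) and n elements per channel, d(batch_mean_c)/d(x_i) = 1/n, and for the biased variance d(batch_std_c)/d(x_i) = (x_i - batch_mean_c) / (n * batch_std_c), which is exactly the expression coded above. A self-contained central-difference sanity sketch, assuming only NumPy (the function name is hypothetical):

import numpy as np

def check_fold_grad_formula(eps=1e-4):
    rng = np.random.default_rng(0)
    x = rng.uniform(1, 10, size=(2, 3, 4, 4))
    d_batch_mean = rng.uniform(1, 10, size=3)
    d_batch_std = rng.uniform(1, 10, size=3)
    n = x.shape[0] * x.shape[2] * x.shape[3]
    batch_mean = x.mean(axis=(0, 2, 3))
    batch_std = np.sqrt(x.var(axis=(0, 2, 3)) + 1e-12)
    # Analytic gradient, identical in form to np_result above.
    dx = d_batch_mean.reshape(1, -1, 1, 1) / n + d_batch_std.reshape(1, -1, 1, 1) * (
        x - batch_mean.reshape(1, -1, 1, 1)) / batch_std.reshape(1, -1, 1, 1) / n
    # Scalar loss whose gradient w.r.t. x is exactly dx.
    def loss(xx):
        m = xx.mean(axis=(0, 2, 3))
        s = np.sqrt(xx.var(axis=(0, 2, 3)) + 1e-12)
        return (d_batch_mean * m).sum() + (d_batch_std * s).sum()
    xp, xm = x.copy(), x.copy()
    xp[0, 0, 0, 0] += eps
    xm[0, 0, 0, 0] -= eps
    numeric = (loss(xp) - loss(xm)) / (2 * eps)
    assert np.isclose(dx[0, 0, 0, 0], numeric, rtol=1e-5)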
@@ -0,0 +1,116 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context

context.set_context(device_target='GPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.BatchNormFold(freeze_bn=10)

    @ms_function
    def construct(self, x, mean, variance, current_step):
        a, b, c, d = self.op(x, mean, variance, current_step)
        return a, b, c, d


def np_result(x, mean, var, momentum, epsilon):
    np_mean = x.mean(axis=(0, 2, 3))
    np_var = x.var(axis=(0, 2, 3))
    n = x.shape[0] * x.shape[2] * x.shape[3]
    mean_update = momentum * np_mean + (1 - momentum) * mean
    var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var
    np_var = np.sqrt(np_var + epsilon)
    delay_mean = mean.copy()
    delay_std = np.sqrt(var + epsilon)
    return np_mean, np_var, mean_update, var_update, delay_mean, delay_std


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
    mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    variance = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([0]).astype('int32')
    ms_mean = Tensor(mean)
    ms_var = Tensor(variance)
    batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
                                                       Tensor(current_step))

    expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
    assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(ms_var.asnumpy(), expect4, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold2():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[3, c, 512, 512]).astype('float32')
    mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    variance = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([0]).astype('int32')
    ms_mean = Tensor(mean)
    ms_var = Tensor(variance)
    batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
                                                       Tensor(current_step))
    expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
    assert np.allclose(batch_mean.asnumpy(), expect1, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(batch_var.asnumpy(), expect2, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(ms_mean.asnumpy(), expect3, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_batchnorm_fold_freeze():
    net = Net()
    c = 64
    x = np.random.uniform(1, 10, size=[3, c, 32, 32]).astype('float32')
    mean = np.random.uniform(1, 10, size=[c]).astype('float32')
    variance = np.random.uniform(1, 10, size=[c]).astype('float32')
    current_step = np.array([10]).astype('int32')
    ms_mean = Tensor(mean)
    ms_var = Tensor(variance)
    batch_mean, batch_var, delay_mean, delay_std = net(Tensor(x), ms_mean, ms_var,
                                                       Tensor(current_step))
    expect1, expect2, expect3, expect4, expect5, expect6 = np_result(x, mean, variance, 0.9, 1e-12)
    assert np.allclose(batch_mean.asnumpy(), np.zeros_like(mean), rtol=1.e-7, atol=1.e-5)
    assert np.allclose(batch_var.asnumpy(), np.ones_like(mean), rtol=1.e-7, atol=1.e-5)
    assert np.allclose(ms_mean.asnumpy(), mean, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(ms_var.asnumpy(), variance, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_mean.asnumpy(), expect5, rtol=1.e-7, atol=1.e-5)
    assert np.allclose(delay_std.asnumpy(), expect6, rtol=1.e-7, atol=1.e-5)
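One detail worth calling out in the file above (reader annotation, not part of the patch): np_result's running-variance update multiplies the batch variance by n / (n - 1). That is the Bessel correction, since np.var is the biased estimator while the running variance stores the unbiased one. A short NumPy check of the identity:

import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 2, 4, 4))
n = x.shape[0] * x.shape[2] * x.shape[3]
biased = x.var(axis=(0, 2, 3))            # divides by n
unbiased = x.var(axis=(0, 2, 3), ddof=1)  # divides by n - 1
assert np.allclose(biased * n / (n - 1), unbiased)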
@@ -0,0 +1,55 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context

context.set_context(device_target='GPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op_w = P.CorrectionMulGrad()

    @ms_function
    def construct(self, dy, x, batch_std, running_std):
        dx, d_batch_std = self.op_w(dy, x, batch_std, running_std)
        return dx, d_batch_std


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_correction_mul_grad():
    net = Net()
    co, ci, h, w = 64, 1, 32, 32
    dout = np.random.uniform(-0.1, 0.1, size=[co, ci, h, w]).astype('float32')
    x = np.random.uniform(1, 1, size=[co, ci, h, w]).astype('float32')
    batch_std = np.random.uniform(1, 10, size=[co]).astype('float32')
    running_std = np.random.uniform(1, 10, size=[co]).astype('float32')
    output = net(Tensor(dout), Tensor(x), Tensor(batch_std), Tensor(running_std))
    expect = [0, 0]
    expect[0] = dout * np.reshape(batch_std / running_std, (co, 1, 1, 1))
    expect[1] = np.sum(dout * x, (1, 2, 3)) / running_std
    for i, v in enumerate(output):
        assert np.allclose(v.asnumpy(), expect[i], rtol=1.e-5, atol=1.e-5)
@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import mindspore.context as context

context.set_context(device_target='GPU')


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.op = P.CorrectionMul()

    @ms_function
    def construct(self, x, batch_var, moving_var):
        return self.op(x, batch_var, moving_var)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_correction_mul():
    net = Net()
    co = 64
    x = np.random.uniform(-1, 1, size=[co, 64, 32, 32]).astype('float32')
    bv = np.random.uniform(1, 2, size=[co]).astype('float32')
    mv = np.random.uniform(1, 2, size=[co]).astype('float32')
    output = net(Tensor(x), Tensor(bv), Tensor(mv))
    expect = x * np.reshape(bv, (co, 1, 1, 1)) / np.reshape(mv, (co, 1, 1, 1))
    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    assert np.all(diff < error)
    assert np.all(diff > error * -1)
    assert output.asnumpy().shape == expect.shape