fix pylint warnings

pull/1348/head
Yi Huaijie 5 years ago
parent bd845dd0b7
commit 14fe72f383

@@ -13,9 +13,8 @@
# limitations under the License.
# ============================================================================
import numpy as np
import os
import pytest
import numpy as np
import mindspore as ms
import mindspore.communication.management as distributedTool
@@ -58,11 +57,12 @@ class Onehot(Cell):
self.off_value = Tensor(off_value, ms.float32)
self.transpose = P.Transpose().set_strategy(strategy=trans_stra)
self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1)))
self.axis = axis
def construct(self, input, indices):
def construct(self, input_, indices):
x = self.onehot(indices, self.depth, self.on_value, self.off_value)
x = self.transpose(x, (1, 0))
x = self.sub(input, x)
x = self.sub(input_, x)
return x
@@ -100,9 +100,9 @@ class DataGenerator():
class OneHotFactory:
def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None):
dataGen = DataGenerator()
self.input_full, self.input_part = dataGen.input_data((classes, batch_size))
self.label_full, self.label_part = dataGen.label_data((batch_size,), classes)
data_gen = DataGenerator()
self.input_full, self.input_part = data_gen.input_data((classes, batch_size))
self.label_full, self.label_part = data_gen.label_data((batch_size,), classes)
self.depth = classes
self.on_value = on_value
self.off_value = off_value
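Note on the import hunk at the top of this file (the same reordering recurs in most files below): pylint's wrong-import-order check wants standard-library imports ahead of third-party ones, so `import numpy as np` moves below `import os`. A purely illustrative header layout, grouped the way the convention describes (the blank-line grouping is an assumption, not shown in the diff):

import os                # standard library first

import pytest            # third-party packages next
import numpy as np

import mindspore as ms   # project / first-party imports last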

@@ -13,23 +13,21 @@
# limitations under the License.
# ============================================================================
import numpy as np
import os
import pytest
import numpy as np
from numpy import allclose
import mindspore as ms
import mindspore.communication.management as distributedTool
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train import Model, ParallelMode
from mindspore.train import Model
from mindspore.train.callback import Callback
np.set_printoptions(threshold=np.inf)
@@ -86,8 +84,8 @@ class DataGenerator():
datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id])
def label_data(self, shape, embed):
data = (self.generate_data(shape) * (embed - 1)).astype(np.int32)
def label_data(self, shape, embed_):
data = (self.generate_data(shape) * (embed_ - 1)).astype(np.int32)
stra = [1] * len(shape)
stra[0] = device_num
datas = self.get_parallel_blocks(data, stra)
@@ -110,9 +108,8 @@ class Dataset():
raise StopIteration
self.index += 1
if self.input_num == 2:
return self.predict, self.label
else:
return self.predict,
return (self.predict, self.label)
return (self.predict,)
def reset(self):
self.index = 0
@@ -129,15 +126,17 @@ class ModelCallback(Callback):
super(ModelCallback, self).__init__()
self.loss_list = []
def epoch_end(self, run_context, *args):
def epoch_end(self, run_context):
cb_params = run_context.original_args()
result = cb_params.net_outputs
self.loss_list.append(result.asnumpy().mean())
class SoftmaxCrossEntropyExpand(Cell):
def __init__(self, sparse=False, stra_list=[]):
def __init__(self, sparse=False, stra_list=None):
super(SoftmaxCrossEntropyExpand, self).__init__()
if stra_list is None:
stra_list = []
if len(stra_list) < 11:
stra_list = [None] * 11
self.exp = P.Exp()
@@ -171,8 +170,10 @@ class SoftmaxCrossEntropyExpand(Cell):
class MatmulNet(Cell):
def __init__(self, matmul_stra=None, loss_stra_list=[]):
def __init__(self, matmul_stra=None, loss_stra_list=None):
super(MatmulNet, self).__init__()
if loss_stra_list is None:
loss_stra_list = []
self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra)
self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list)
self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight")
@@ -185,9 +186,9 @@ class MatmulNet(Cell):
class LossFactory():
def __init__(self):
dataGen = DataGenerator()
self.input_full, self.input_part = dataGen.input_data((batch_size, embed))
self.label_full, self.label_part = dataGen.label_data((batch_size,), embed)
data_gen = DataGenerator()
self.input_full, self.input_part = data_gen.input_data((batch_size, embed))
self.label_full, self.label_part = data_gen.label_data((batch_size,), embed)
def single_matmul_trains(self):
single_callback = ModelCallback()
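Note: `SoftmaxCrossEntropyExpand` and `MatmulNet` above swap the `stra_list=[]` / `loss_stra_list=[]` defaults for `None` plus an in-body check. Pylint flags the old form as dangerous-default-value (W0102) because the default list is created once, when the function is defined, and then shared by every call. A minimal sketch of the failure mode, with invented names:

def append_item(item, bucket=[]):          # one shared list for every call
    bucket.append(item)
    return bucket

def append_item_fixed(item, bucket=None):  # the pattern used in the diff
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_item(1), append_item(2))              # [1, 2] [1, 2]
print(append_item_fixed(1), append_item_fixed(2))  # [1] [2]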

@@ -23,4 +23,4 @@ import pytest
def test_expand_loss():
sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_auto_parallel_loss_expand.sh")
assert (ret == 0)
assert ret == 0

@@ -14,9 +14,8 @@
# ============================================================================
import os
import pytest
def test_expand_loss():
ret = os.system("sh run_onehot_model_parallel.sh")
assert (ret == 0)
assert ret == 0
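Note: both small test files above change `assert (ret == 0)` to `assert ret == 0`. Beyond silencing the superfluous-parens warning, dropping the parentheses avoids a real trap: once a message is added, the parenthesised form asserts a non-empty tuple and can never fail. Illustrative only (the message text is made up):

ret = 1

# The parentheses plus a message build a 2-tuple, which is always truthy,
# so this assert passes silently even though ret != 0.
assert (ret == 0, "shell script returned a non-zero exit code")

# The unparenthesised form actually checks the condition.
try:
    assert ret == 0, "shell script returned a non-zero exit code"
except AssertionError as err:
    print("caught:", err)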

@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
import numpy as np
import os
import numpy as np
import pytest
import mindspore.common.dtype as mstype
@@ -37,31 +37,29 @@ init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
def weight_variable(shape, factor=0.1):
def weight_variable():
return One()
def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
init_value = weight_variable((out_channels, in_channels, 3, 3))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
init_value = weight_variable((out_channels, in_channels, 1, 1))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
init_value = weight_variable((out_channels, in_channels, 7, 7))
init_value = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
def _fused_bn(channels, momentum=0.9):
init_weight = weight_variable((channels,))
init_bias = weight_variable((channels,))
return nn.BatchNorm2d(channels, momentum=momentum)
@@ -210,8 +208,8 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True)
self.end_point = nn.Dense(2048, num_classes, has_bias=True,
weight_init=weight_variable((num_classes, 2048)),
bias_init=weight_variable((num_classes,)))
weight_init=weight_variable(),
bias_init=weight_variable())
self.squeeze = P.Squeeze()
self.cast = P.Cast()
@@ -345,9 +343,8 @@ class Dataset():
raise StopIteration
self.index += 1
if self.input_num == 2:
return self.predict, self.label
else:
return self.predict,
return (self.predict, self.label)
return (self.predict,)
def reset(self):
self.index = 0
@@ -364,7 +361,7 @@ class ModelCallback(Callback):
super(ModelCallback, self).__init__()
self.loss_list = []
def epoch_end(self, run_context, *args):
def epoch_end(self, run_context):
cb_params = run_context.original_args()
result = cb_params.net_outputs
self.loss_list.append(result.asnumpy().mean())
@@ -376,9 +373,9 @@ class ModelCallback(Callback):
def test_train_feed(num_classes=8192):
set_algo_parameters(elementwise_op_strategy_follow=True)
parallel_callback = ModelCallback()
dataGen = DataGenerator()
input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
label_full, label_part = dataGen.label_data((32 * 2,))
data_gen = DataGenerator()
_, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
_, label_part = data_gen.label_data((32 * 2,))
dataset = Dataset(input_part, label_part)
net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
@@ -396,9 +393,9 @@ def test_train_feed(num_classes=8192):
def test_train_feed2(num_classes=1001):
set_algo_parameters(elementwise_op_strategy_follow=True)
parallel_callback = ModelCallback()
dataGen = DataGenerator()
input_full, input_part = dataGen.input_data((32 * 2, 3, 224, 224))
label_full, label_part = dataGen.label_data((32 * 2,))
data_gen = DataGenerator()
_, input_part = data_gen.input_data((32 * 2, 3, 224, 224))
_, label_part = data_gen.label_data((32 * 2,))
dataset = Dataset(input_part, label_part)
net = resnet50(num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
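Note: the `Dataset.__next__` hunk above (and the identical one in the loss-expand test earlier) flattens `return ... else: return ...` into an `if` plus an unconditional return, the shape pylint's no-else-return suggestion (R1705) asks for. A standalone sketch with placeholder values:

def next_item_old(two_inputs, predict, label):
    if two_inputs:
        return (predict, label)
    else:                       # pylint: the else is redundant after a return
        return (predict,)

def next_item_new(two_inputs, predict, label):
    if two_inputs:
        return (predict, label)
    return (predict,)

assert next_item_old(True, "p", "l") == next_item_new(True, "p", "l") == ("p", "l")
assert next_item_old(False, "p", "l") == next_item_new(False, "p", "l") == ("p",)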

@@ -25,7 +25,6 @@ from mindspore.nn import Dense
from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations import Split
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
from mindspore.ops.operations.comm_ops import Broadcast

@@ -16,8 +16,8 @@
@File : test_data_parallel_lenet.py
@Desc : test data parallel lenet
"""
import numpy as np
import os
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
@@ -80,7 +80,6 @@ def test_lenet5_train_step_training_pynative():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=8, mirror_mean=True)
size = 3
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
DatasetLenet(predict, label, 2)

@@ -19,7 +19,7 @@ from mindspore.parallel._utils import _reset_op_id
from mindspore.parallel.algo_parameter_config import reset_algo_parameters
def setup_module(module):
def setup_module():
auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
reset_cost_model_context()

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import pytest
import mindspore as ms
@@ -155,9 +155,9 @@ class AddReluFactory:
def grad_cmp(self):
input_grad_mindspore = self.grad_mindspore_impl()
input_grad_mindspore_parallel = self.grad_mindspore_parallel_impl()
input_grad_mindspore0 = input_grad_mindspore[0].asnumpy()
_ = input_grad_mindspore[0].asnumpy()
input_grad_mindspore1 = input_grad_mindspore[1].asnumpy()
input_grad_mindspore_parallel0 = input_grad_mindspore_parallel[0].asnumpy()
_ = input_grad_mindspore_parallel[0].asnumpy()
input_grad_mindspore_parallel1 = input_grad_mindspore_parallel[1].asnumpy()
assert np.allclose(input_grad_mindspore1, input_grad_mindspore_parallel1, 0.0001, 0.0001)
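Note: in `grad_cmp` above, the element-0 gradients are converted to NumPy but never compared, so their names trip pylint's unused-variable check (W0612); rebinding them to `_` keeps the conversion while making the deliberately-ignored intent explicit. A tiny stand-alone illustration with invented numbers:

import numpy as np

grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
grads_parallel = [np.array([1.0, 2.0]), np.array([3.0001, 4.0001])]

_ = grads[0]             # still evaluated, deliberately unused
_ = grads_parallel[0]
assert np.allclose(grads[1], grads_parallel[1], 0.0001, 0.0001)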

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
from numpy import allclose
import mindspore.communication.management as distributedTool
@@ -273,7 +273,7 @@ class Conv2dFactory:
stride=self.stride, pad_mode=self.pad_mode,
padding=self.padding, dilation=self.dilation,
group=self.group, has_bias=True, weight_init=weight,
bias_init=bias, )
bias_init=bias,)
else:
net = Conv2d(in_channels=self.in_c, out_channels=self.out_c,
kernel_size=(self.kernel_h, self.kernel_w),

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import mindspore as ms
import mindspore.communication.management as distributedTool
@@ -45,8 +45,8 @@ class Net(Cell):
super(Net, self).__init__()
self.drop = Dropout(keep_prob, seed0, seed1, dtype=ms.float32, strategy=strategy)
def construct(self, input):
x = self.drop(input)
def construct(self, input_):
x = self.drop(input_)
return x
@@ -83,16 +83,16 @@ class DropoutFactory:
i += 1
return blocks
def d4_tensor_compare(self, input, out_me):
[a, b, c, d] = input.shape
def d4_tensor_compare(self, input_, out_me):
[a, b, c, d] = input_.shape
for i in range(a):
for j in range(b):
for k in range(c):
for e in range(d):
if out_me[i, j, k, e] == 0:
assert True == True
assert True
else:
assert np.allclose(out_me[i, j, k, e], input[i, j, k, e] * (1 / 0.4), 0.0001, 0.0001)
assert np.allclose(out_me[i, j, k, e], input_[i, j, k, e] * (1 / 0.4), 0.0001, 0.0001)
def forward_mindspore_parallel_impl(self):
x = Tensor(self.input_np)
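Note: the dropout test above renames the `construct` parameter from `input` to `input_` because a parameter named `input` shadows the builtin inside the method, which pylint reports as redefined-builtin (W0622); a trailing underscore is the usual PEP 8 escape for such clashes. (The same file also collapses `assert True == True` to a plain `assert True`.) A minimal sketch with a made-up function:

def double(input):         # shadows builtins.input within this function
    return input * 2

def double_fixed(input_):  # same behaviour, no shadowing
    return input_ * 2

assert double(21) == double_fixed(21) == 42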

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import mindspore as ms
import mindspore.communication.management as distributedTool
@@ -83,8 +83,8 @@ class Grad(Cell):
class MatmulAllgatherFactory:
def __init__(self, inputx_shape, inputy_shape, x_stra, y_stra):
self.inputx = self.GenValue(inputx_shape, 10)
self.inputy = self.GenValue(inputy_shape, 20)
self.inputx = self.gen_value(inputx_shape, 10)
self.inputy = self.gen_value(inputy_shape, 20)
self.x_stra = x_stra
self.y_stra = y_stra
stra_size = 1
@@ -92,7 +92,7 @@ class MatmulAllgatherFactory:
stra_size = stra_size * s
self.stra_size = stra_size
def GenValue(self, input_shape, delta):
def gen_value(self, input_shape, delta):
size = 1
for s in input_shape:
size = size * s
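Note: `GenValue` becomes `gen_value` and `dataGen` becomes `data_gen` in these files because pylint's default naming rules (invalid-name, C0103) expect snake_case for methods and local variables. A purely illustrative class that only mirrors the rename; its body is invented, not the repository's logic:

class ValueFactory:
    """Hypothetical stand-in for the test factory classes."""

    def gen_value(self, input_shape, delta):   # snake_case, was GenValue
        size = 1
        for dim in input_shape:
            size = size * dim
        return size + delta

factory = ValueFactory()
assert factory.gen_value((4, 8), 10) == 42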

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import mindspore as ms
import mindspore.communication.management as distributedTool
@@ -87,9 +87,9 @@ class Grad(Cell):
class MatmulReduceFactory:
def __init__(self, inputx_shape, inputy_shape, inputz_shape, x_stra, y_stra, z_stra):
self.inputx = self.GenValue(inputx_shape, 10)
self.inputy = self.GenValue(inputy_shape, 20)
self.inputz = self.GenValue(inputz_shape, 30)
self.inputx = self.gen_value(inputx_shape, 10)
self.inputy = self.gen_value(inputy_shape, 20)
self.inputz = self.gen_value(inputz_shape, 30)
self.x_stra = x_stra
self.y_stra = y_stra
self.z_stra = z_stra
@@ -98,7 +98,7 @@ class MatmulReduceFactory:
stra_size = stra_size * s
self.stra_size = stra_size
def GenValue(self, input_shape, delta):
def gen_value(self, input_shape, delta):
size = 1
for s in input_shape:
size = size * s

@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
import numpy as np
import mindspore.communication.management as distributedTool
from mindspore import context

@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
import numpy as np
import mindspore.communication.management as distributedTool
from mindspore import context

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
from numpy import allclose
import mindspore.communication.management as distributedTool
@@ -109,7 +109,7 @@ class BatchmatmulFactory:
inputb_shape[-1] = inputb_shape[-2]
inputb_shape[-2] = temp
if (len(inputa_shape) >= len(inputb_shape)):
if len(inputa_shape) >= len(inputb_shape):
out_shape = list(inputa_shape)
out_shape[-1] = inputb_shape[-1]
else:
@@ -127,7 +127,7 @@ class BatchmatmulFactory:
strategy2[-1] = strategy2[-2]
strategy2[-2] = temp
if (len(strategy1) >= len(strategy2)):
if len(strategy1) >= len(strategy2):
out_strategy = strategy1.copy()
out_strategy[-1] = strategy2[-1]
else:
@@ -189,13 +189,13 @@ class BatchmatmulFactory:
i += 1
return blocks
"""
The upper limit of each dimension of shape is 2, 4, 8
"""
def id_to_list(self, id, shape):
def id_to_list(self, id_, shape):
"""
The upper limit of each dimension of shape is 2, 4, 8
"""
result = []
r = id
r = id_
for i in range(0, len(shape)):
v = 1
for j in range(i + 1, len(shape)):
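Note: `id_to_list` above renames its `id` parameter to `id_` (the old name shadowed the builtin `id()`, another redefined-builtin case) and moves the stray module-level string into the function as a docstring. The hunk is cut off mid-loop; a guessed completion of the usual row-major index decomposition, under that assumption and not the repository's exact code, would be:

def id_to_list(id_, shape):
    """Split a flat device id into per-dimension indices (row-major).

    Each dimension of shape is typically capped at 2, 4 or 8 in these tests.
    """
    result = []
    r = id_
    for i in range(len(shape)):
        v = 1
        for j in range(i + 1, len(shape)):
            v = v * shape[j]           # size of the trailing dimensions
        result.append(r // v)
        r = r % v
    return result

assert id_to_list(5, [2, 4]) == [1, 1]        # 5 == 1 * 4 + 1
assert id_to_list(7, [2, 2, 2]) == [1, 1, 1]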

@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
import numpy as np
import mindspore.communication.management as distributedTool
from mindspore import context

@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import pytest
from numpy import allclose
import mindspore as ms
import mindspore.communication.management as distributedTool

@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
from numpy import allclose
import numpy as np
import mindspore as ms
import mindspore.communication.management as distributedTool

@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import pytest
from numpy import allclose
import mindspore as ms
import mindspore.communication.management as distributedTool
@@ -48,6 +47,7 @@ class PReLU(Cell):
super(PReLU, self).__init__()
self.add = P.TensorAdd(strategy=strategy1_)
self.prelu = P.PReLU(strategy=strategy_)
self.channel = channel
def construct(self, x, z, w):
out = self.add(x, z)
@@ -59,8 +59,8 @@ class Grad(Cell):
super(Grad, self).__init__()
self.network = network
def construct(self, input, z, w, output_grad):
return grad_all_with_sens(self.network)(input, z, w, output_grad)
def construct(self, input_, z, w, output_grad):
return grad_all_with_sens(self.network)(input_, z, w, output_grad)
class PReLUFactory:

@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
import numpy as np
from numpy import allclose as allclose_nparray
import mindspore as ms
@@ -129,9 +128,9 @@ class ReduceMeanFactory:
self.out_id = self.list_to_id(device_index, self.out_strategy)
print(self.out_id)
def id_to_list(self, id, shape):
def id_to_list(self, id_, shape):
result = []
r = id
r = id_
for i in range(0, len(shape)):
v = 1
for j in range(i + 1, len(shape)):

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import numpy as np
import pytest
from numpy import allclose as allclose_nparray

@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
import pytest
import numpy as np
from numpy import allclose as allclose_nparray
import mindspore.communication.management as distributedTool
@@ -118,9 +117,9 @@ class TransposeFactory:
i += 1
return blocks
def id_to_list(self, id, shape):
def id_to_list(self, id_, shape):
result = []
r = id
r = id_
for i in range(0, len(shape)):
v = 1
for j in range(i + 1, len(shape)):

@@ -54,7 +54,7 @@ class Grad(nn.Cell):
return C.grad_all(self.network)(x, y)
def compile(net, x, y):
def compile_net(net, x, y):
net.set_auto_parallel()
_executor.compile(net, x, y)
@@ -69,7 +69,7 @@ def test_add_relu_stride_slice():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
compile(net, x, y)
compile_net(net, x, y)
def test_add_relu_all_gather():
@@ -82,4 +82,4 @@ def test_add_relu_all_gather():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
compile(net, x, y)
compile_net(net, x, y)
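Note: renaming the helper from `compile` to `compile_net` above removes a module-level shadow of the builtin `compile()`; unlike a shadowed parameter, a module-level name hides the builtin for every later use in the file. Illustrative only, the helper below is a fake:

def compile_net(net, x, y):
    # hypothetical helper mirroring the renamed test utility
    return "compiled {} with inputs {}, {}".format(net, x, y)

# With nothing shadowing it, the builtin compile() stays usable in the module.
code = compile("6 * 7", "<demo>", "eval")
assert eval(code) == 42
print(compile_net("AddRelu", 1, 2))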

@@ -17,7 +17,6 @@ import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore import context
from mindspore.common.api import _executor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
@@ -131,56 +130,56 @@ def test_allreduce_fusion_parameters():
cost_model_context.reset_cost_model_context()
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=2)
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert (algorithm == 2)
assert algorithm == 2
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert (algorithm == 1)
assert algorithm == 1
cost_model_context.reset_cost_model_context()
algorithm = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_algorithm')
assert (algorithm == 0)
assert algorithm == 0
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
fusion_times = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_times')
assert (fusion_times == 2)
assert fusion_times == 2
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.2)
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
assert (tail_percent == 0.2)
assert tail_percent == 0.2
cost_model_context.reset_cost_model_context()
tail_percent = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_percent')
assert (tail_percent == 0.1)
assert tail_percent == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_time=0.2)
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
assert (tail_time == 0.2)
assert tail_time == 0.2
cost_model_context.reset_cost_model_context()
tail_time = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_tail_time')
assert (tail_time == 0.1)
assert tail_time == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_inherent_time=0.2)
allreduce_inherent_time = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_allreduce_inherent_time')
assert (allreduce_inherent_time == 0.2)
assert allreduce_inherent_time == 0.2
cost_model_context.reset_cost_model_context()
allreduce_inherent_time = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_allreduce_inherent_time')
assert (allreduce_inherent_time == 0.1)
assert allreduce_inherent_time == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_allreduce_bandwidth=0.2)
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
assert (allreduce_bandwidth == 0.2)
assert allreduce_bandwidth == 0.2
cost_model_context.reset_cost_model_context()
allreduce_bandwidth = cost_model_context.get_cost_model_context('costmodel_allreduce_fusion_allreduce_bandwidth')
assert (allreduce_bandwidth == 0.1)
assert allreduce_bandwidth == 0.1
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_computation_time_parameter=0.2)
computation_time_parameter = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_computation_time_parameter')
assert (computation_time_parameter == 0.2)
assert computation_time_parameter == 0.2
cost_model_context.reset_cost_model_context()
computation_time_parameter = cost_model_context.get_cost_model_context(
'costmodel_allreduce_fusion_computation_time_parameter')
assert (computation_time_parameter == 0.1)
assert computation_time_parameter == 0.1
def test_allreduce_fusion1():
@@ -201,7 +200,7 @@ def test_allreduce_fusion1():
'backbone2.fc2.weight': 1,
'backbone2.fc1.weight': 1,
'backbone1.fc1.weight': 1}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@@ -214,7 +213,7 @@ def test_allreduce_fusion2():
net = SimpleDMLNet(DenseNet1(has_bias=False, activation=None), DenseNet2(has_bias=False, activation=None))
allreduce_fusion_dict = train_common(net)
expect_dict = {}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@@ -240,7 +239,7 @@ def test_allreduce_fusion3():
'backbone1.fc2.weight': 2,
'backbone1.fc1.bias': 2,
'backbone1.fc1.weight': 2}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@@ -267,7 +266,7 @@ def test_allreduce_fusion4():
'backbone1.fc2.weight': 1,
'backbone1.fc1.weight': 1}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()
@@ -295,7 +294,7 @@ def test_allreduce_fusion5():
'backbone1.fc4.weight': 2,
'backbone1.fc3.weight': 2,
'backbone1.fc2.weight': 1,
'backbone1.fc1.weight': 1, }
'backbone1.fc1.weight': 1,}
assert (allreduce_fusion_dict == expect_dict)
assert allreduce_fusion_dict == expect_dict
cost_model_context.reset_cost_model_context()

Some files were not shown because too many files have changed in this diff.
