!1402 for second order codes

Merge pull request !1402 from zongha/master
pull/1402/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit aeffccb7f8

@ -1,126 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""linear_warmup_lr"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
"""linear_warmup_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
# linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * i / decay_steps))
decayed = cosine_decay
lr = base_lr * decayed
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
"""warmup_cosine_annealing_lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch * 0.99)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * num_periods * i / decay_steps))
decayed = linear_decay * cosine_decay
lr = base_lr * decayed + 0.000005
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
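In the notation above, once warmup ends warmup_cosine_annealing_lr combines a linear ramp-down with a cosine term:

    lr_i = base_lr * ((total_steps - i) / decay_steps) * 0.5 * (1 + cos(2 * pi * num_periods * i / decay_steps)) + 5e-6

so with the default num_periods = 0.5 the cosine completes half a period over decay_steps, and the small additive constant keeps the final rate from collapsing to exactly zero.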
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate
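A minimal sketch of how these generators might be exercised; the step counts and rates below are illustrative only and are not taken from any config in this PR:

    steps_per_epoch = 100  # hypothetical dataset size
    cosine = warmup_cosine_annealing_lr(0.035, steps_per_epoch, 5, 50, 50)
    poly = get_lr(global_step=0, lr_init=0.01, lr_end=0.0, lr_max=0.1,
                  warmup_epochs=5, total_epochs=50,
                  steps_per_epoch=steps_per_epoch, lr_decay_mode='poly')
    print(cosine.shape, poly.shape)  # both are float32 numpy arrays

Either array is then typically wrapped in a mindspore Tensor before being handed to the optimizer, as train.py does further down.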

@ -13,12 +13,10 @@
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
from mindspore import context
from mindspore._checkparam import check_bool
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
@ -42,19 +40,9 @@ class DatasetHelper:
>>> outputs = network(*inputs)
"""
def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
check_bool(dataset_sink_mode)
iterclass = _DatasetIterGE
if not dataset_sink_mode:
iterclass = _DatasetIterFeed
elif not context.get_context("enable_ge"):
if context.get_context("enable_loop_sink"):
iterclass = _DatasetIterMSLoopSink
else:
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, first_order_iter)
self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)
def __iter__(self):
return self.iter.__iter__()
@ -85,12 +73,6 @@ class _DatasetIter:
self.dataset = dataset
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
def __iter__(self):
self.ind = 0
@ -109,83 +91,28 @@ class _DatasetIter:
loop_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
if dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.')
loop_count = int(dataset.get_dataset_size() / loop_size)
return loop_count
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (enable_loop_sink=True)"""
"""Iter for context (device_target=Ascend)"""
def __init__(self, dataset, first_order_iter):
def __init__(self, dataset, iter_first_order):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
# self.loop_count = self.get_loop_count(dataset)
loop_size = dataset.__loop_size__ + first_order_iter
loop_size = dataset.__loop_size__ + iter_first_order
self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
# compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
# times the batch dimension of tensors for run. Now only support LoopSink.
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
def op():
return tuple()
self.op = op
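A worked example of the sink-count arithmetic above, using hypothetical numbers rather than values from this PR: if dataset.get_dataset_size() is 5004, dataset.__loop_size__ is 830 and iter_first_order is 4, then loop_size is 834 and self.loop_count becomes int(5004 / 834) * 2 = 12, i.e. each epoch is launched as six long second-order sinks interleaved with six short first-order sinks.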
class _DatasetIterMS(_DatasetIter):
"""Iter for context (enable_loop_sink=False)"""
def __init__(self, dataset, first_order_order):
super(_DatasetIterMS, self).__init__(dataset)
self.loop_count = dataset.get_dataset_size()
self.loop_size = 1
queue_name = dataset.__ME_INITED__
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
class _DatasetIterGE(_DatasetIter):
"""Iter for ge"""
def __init__(self, dataset):
super(_DatasetIterGE, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
parallel_mode = _get_parallel_mode()
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
batch_expand_num = 1
if self.need_to_full:
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterFeed:
"""Iter for feed data"""
def __init__(self, dataset, first_order_order):
self.dataset = dataset
self.device_num = _get_device_num()
self.global_rank = _get_global_rank()
self.repeat_count = dataset.get_repeat_count()
self.repeat_ind = 0
self.loop_count = dataset.get_dataset_size()
self.ind = 0
parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
def __iter__(self):
if self.repeat_ind % self.repeat_count == 0:
self.iter = self.dataset.__iter__()
self.repeat_ind += 1
self.ind = 0
return self
def __next__(self):
if self.ind >= self.loop_count:
raise StopIteration()
self.ind += 1
data = self.iter.__next__()
if self.need_to_full:
return _to_full_tensor(data, self.device_num, self.global_rank)
return _to_tensor(data)

File diff suppressed because it is too large.

@ -14,22 +14,24 @@
# ============================================================================
"""ResNet."""
import math
import mindspore.nn as nn
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from second_order.thor_layer import Conv2d_Thor, Dense_Thor
from model.thor_layer import Conv2d_Thor, Dense_Thor
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
res = 1
elif nonlinearity == 'tanh':
return 5.0 / 3
res = 5.0 / 3
elif nonlinearity == 'relu':
return math.sqrt(2.0)
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
@ -38,16 +40,17 @@ def calculate_gain(nonlinearity, param=None):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
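For reference, the gains this helper returns for the common cases (values rounded):

    calculate_gain('linear')      # 1.0
    calculate_gain('tanh')        # 5 / 3 ≈ 1.667
    calculate_gain('relu')        # sqrt(2) ≈ 1.414
    calculate_gain('leaky_relu')  # sqrt(2 / (1 + 0.01 ** 2)) ≈ 1.414 with the default slope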
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
@ -67,7 +70,6 @@ def _calculate_correct_fan(tensor, mode):
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
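The kaiming_uniform initializer used by _conv3x3 and _fc below is not visible in this hunk; the following is only a sketch of the usual Kaiming-uniform recipe built on the two helpers above, not necessarily the exact implementation in this file:

    def kaiming_uniform_sketch(shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
        # Assumed recipe: uniform bound chosen so the weight variance equals gain**2 / fan.
        fan = _calculate_correct_fan(shape, mode)
        gain = calculate_gain(nonlinearity, a)
        bound = gain * math.sqrt(3.0 / fan)
        return np.random.uniform(-bound, bound, size=shape).astype(np.float32)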
@ -93,8 +95,6 @@ def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, freq
return Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
# return nn.Conv2d(in_channel, out_channel,
# kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
@ -125,7 +125,7 @@ def _bn_last(channel):
def _fc(in_channel, out_channel, damping, loss_scale, frequency):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
@ -133,15 +133,15 @@ def _fc(in_channel, out_channel, damping, loss_scale, frequency):
class ResidualBlock(nn.Cell):
"""
ResNet V1 residual block definition.
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>> ResidualBlock(3, 256, stride=2)
"""
@ -210,7 +210,7 @@ class ResidualBlock(nn.Cell):
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
@ -220,7 +220,7 @@ class ResNet(nn.Cell):
num_classes (int): The number of classes that the training images are belonging to.
Returns:
Tensor, output tensor.
Examples:
>>> ResNet(ResidualBlock,
>>> [3, 4, 6, 3],
@ -290,17 +290,17 @@ class ResNet(nn.Cell):
damping, loss_scale, frequency):
"""
Make stage network of ResNet.
Args:
block (Cell): Resnet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
Returns:
SequentialCell, the output layer.
Examples:
>>> _make_layer(ResidualBlock, 3, 128, 256, 2)
"""
@ -321,7 +321,7 @@ class ResNet(nn.Cell):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1, argmax = self.maxpool(x)
c1, _ = self.maxpool(x)
c2 = self.layer1(c1)
c3 = self.layer2(c2)
@ -338,13 +338,13 @@ class ResNet(nn.Cell):
def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
"""
Get ResNet50 neural network.
Args:
class_num (int): Class number.
Returns:
Cell, cell instance of ResNet50 neural network.
Examples:
>>> net = resnet50(10)
"""

@ -51,6 +51,6 @@ do
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train_0517_1.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
cd ..
done

@ -17,7 +17,6 @@ import argparse
import os
import random
import mindspore.dataset.engine as de
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init
@ -25,19 +24,17 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import ParallelMode
from second_order.model_second_order import Model
from second_order.resnet import resnet50
from second_order.thor import THOR
from model.model_thor import Model
from model.resnet import resnet50
from model.thor import THOR
import numpy as np
from config_imagenet import config
from config import config
from crossentropy import CrossEntropy
from dataset_imagenet import create_dataset
from lr_generator import warmup_cosine_annealing_lr
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@ -50,29 +47,29 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def get_second_order_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
"""get_second_order_lr"""
def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
for i in range(total_steps):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
if epoch >= 39:
lr_local = lr_local * 0.5
if epoch >= 40:
lr_local = lr_local * 0.5
lr_each_step.append(lr_local)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
print("learning_rate_is=====", lr_each_step)
learning_rate = lr_each_step[current_step:]
return learning_rate
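Put differently, the schedule above is lr_init * (1 - epoch / total_epochs) ** decay evaluated per step, halved once the fractional epoch reaches 39 and halved again from 40 onward. With the call used further down, get_model_lr(0, 0.05, 6, 70, 5004), the first step starts very close to 0.05 and the rate decays polynomially over 70 epochs of 5004 steps.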
def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_second_order_damping"""
def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_model_damping"""
damping_each_step = []
total_steps = steps_per_epoch * total_epochs
for step in range(total_steps):
@ -83,26 +80,23 @@ def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs
current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
print("damping_is=========", damping_now)
return damping_now
if __name__ == '__main__':
if args_opt.do_eval:
print("eval")
else:
if args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([80], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
init()
else:
print(" ")
if not args_opt.do_eval and args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
init()
epoch_size = config.epoch_size
damping = get_second_order_damping(0, 0.03, 0.87, 50, 5004)
damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
frequency=config.frequency)
@ -115,17 +109,12 @@ if __name__ == '__main__':
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
lr = Tensor(warmup_cosine_annealing_lr(0.035,
step_size,
config.warmup_epochs,
50,
config.T_max,
config.eta_min))
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
config.momentum, damping, config.frequency,
lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004))
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'spatial_norm' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
config.weight_decay, config.loss_scale)
model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,

@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""batch_matmul_impl"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusBatchMatMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "batchmatmul.so",
"compute_cost": 10,
"kernel_name": "CusBatchMatMul",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
"""CusBatchMatMul"""
return
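The JSON above only registers the TBE kernel's interface (dtypes, formats, binary name) with MindSpore's op registry; the primitive that is actually invoked from Python is the CusBatchMatMul class defined in thor_ops.py later in this PR. A hedged usage sketch, assuming an Ascend device with the compiled batchmatmul kernel available (the import path is a guess based on the other model.* imports in this PR):

    import numpy as np
    from mindspore import Tensor
    from model.thor_ops import CusBatchMatMul  # assumed module path

    bmm = CusBatchMatMul()
    x1 = Tensor(np.ones((4, 128, 128), np.float32))
    x2 = Tensor(np.ones((4, 128, 128), np.float32))
    y = bmm(x1, x2)  # infer_shape returns x1's shape, so y is (4, 128, 128)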

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusCholeskyTrsm"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusCholeskyTrsm",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "choleskytrsm.so",
"compute_cost": 10,
"kernel_name": "CusCholeskyTrsm",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusCholeskyTrsm(input_x, output, kernel_name):
"""CusCholeskyTrsm"""
return

@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusFusedAbsMax1"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusFusedAbsMax1",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "fusedabsmax1.so",
"compute_cost": 10,
"kernel_name": "CusFusedAbsMax1",
"partial_flag": true,
"attr": [
{
"name": "origin_shape",
"param_type": "required",
"type": "listInt",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
"""CusFusedAbsMax1"""
return

@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusImg2ColNC1HWC0"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusImg2ColNC1HWC0",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "img2colnc1hwc0.so",
"compute_cost": 10,
"kernel_name": "CusImg2ColNC1HWC0",
"partial_flag": true,
"attr": [
{
"name": "ksizes",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "strides",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "dilates",
"param_type": "required",
"type": "listInt",
"value": "all"
},
{
"name": "padding",
"param_type": "required",
"type": "str",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"):
"""CusImg2ColNC1HWC0"""
return

@ -0,0 +1,101 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeDenseLeft",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubedenseleft.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeDenseLeft",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeDenseLeft"""
return

@ -0,0 +1,102 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeFraczLeftCast",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubefraczleftcast.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeFraczLeftCast",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float32"
],
"format": [
"FracZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="CusMatMulCubeFraczLeftCast"):
"""CusMatMulCubeFraczLeftCast"""
return

@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCubeFraczRightMul",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcubefraczrightmul.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCubeFraczRightMul",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FracZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 3,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x4",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"FracZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
kernel_name="matmulcube"):
"""CusMatMulCubeFraczRightMul"""
return

@ -0,0 +1,114 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import
from mindspore.ops.op_info_register import op_info_register
from topi.cce import util
# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)
@op_info_register("""{
"op_name": "CusMatMulCube",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matmulcube.so",
"compute_cost": 10,
"kernel_name": "CusMatMulCube",
"partial_flag": true,
"attr": [
{
"name": "transpose_a",
"param_type": "required",
"type": "bool",
"value": "all"
},
{
"name": "transpose_b",
"param_type": "required",
"type": "bool",
"value": "all"
}
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 1,
"dtype": [
"float16"
],
"format": [
"FRACTAL_NZ"
],
"name": "x2",
"need_compile": false,
"param_type": "required",
"shape": "all"
},
{
"index": 2,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "x3",
"need_compile": false,
"param_type": "optional",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"FRACTAL_NZ"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
"""CusMatMulCube"""
return

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusMatrixCombine"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusMatrixCombine",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "matrixcombine.so",
"compute_cost": 10,
"kernel_name": "CusMatrixCombine",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"):
"""CusMatrixCombine"""
return

@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusTranspose02314"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "CusTranspose02314",
"imply_type": "TBE",
"fusion_type": "OPAQUE",
"async_flag": false,
"binfile_name": "transpose02314.so",
"compute_cost": 10,
"kernel_name": "CusTranspose02314",
"partial_flag": true,
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"NC1HWC0"
],
"name": "x1",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float16"
],
"format": [
"DefaultFormat"
],
"name": "y",
"need_compile": false,
"param_type": "required",
"shape": "all"
}
]
}""")
def CusTranspose02314(input_x, output, kernel_name="transpose02314"):
"""CusTranspose02314"""
return

@ -0,0 +1,248 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor_ops"""
import mindspore as ms
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.composite import multitype_ops as C
class CusBatchMatMul(PrimitiveWithInfer):
"""CusMatMulCube definition"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
return bprop
def infer_shape(self, data1_shape, data2_shape):
return data1_shape
def infer_dtype(self, data1_dtype, data2_dtype):
return data1_dtype
class CusCholeskyTrsm(PrimitiveWithInfer):
"""CusCholeskyTrsm definition"""
@prim_attr_register
def __init__(self):
"""init CusCholeskyTrsm"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
def infer_shape(self, data1_shape):
ll = []
m, _ = data1_shape
if m >= 128:
ll = [m // 128, 128, 128]
else:
ll = [1, 64, 64]
return ll
def infer_dtype(self, data1_dtype):
return data1_dtype
class CusFusedAbsMax1(PrimitiveWithInfer):
"""CusCholeskyTrsm definition"""
@prim_attr_register
def __init__(self, origin_shape=[-1, -1]):
"""init CusCholeskyTrsm"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.origin_shape = origin_shape
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
return bprop
def infer_shape(self, data1_shape):
ll = []
if len(data1_shape) == 2:
ll = [1,]
else:
ll = [32, 64]
return ll
def infer_dtype(self, data1_dtype):
return data1_dtype
class CusImg2Col(PrimitiveWithInfer):
"""CusImg2Col definition"""
@prim_attr_register
def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"):
"""init CusImg2Col"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
self.ksizes = ksizes
self.strides = strides
self.dilates = dilates
self.mode = mode
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
return bprop
def infer_shape(self, data1_shape):
bs, c, h, w = data1_shape
_, stride_h, stride_w, _ = self.strides
_, k_w, k_h, _ = self.ksizes
# assert m == n
c0 = 16
c1 = c // 16
if c1 == 0:
c1 = 1
shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0]
return shape
def infer_dtype(self, data1_dtype):
return data1_dtype
class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
"""CusMatMulCube definition"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
return bprop
def infer_shape(self, data1_shape, data2_shape):
return data2_shape
def infer_dtype(self, data1_dtype, data2_dtype):
return ms.common.dtype.tensor_type(getattr(ms, "float16"))
class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
"""CusMatMulCubeFraczRightMul definition"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCubeFraczRightMul"""
self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
def get_bprop(self):
def bprop(x1, x2, x3, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3))
return bprop
def infer_shape(self, data1_shape, data2_shape, data3_shape):
return data1_shape
def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
return ms.common.dtype.tensor_type(getattr(ms, "float32"))
class CusMatMulCube(PrimitiveWithInfer):
"""CusMatMulCube definition"""
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
"""init CusMatMulCube"""
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
self.transpose_a = transpose_a
self.transpose_b = transpose_b
def get_bprop(self):
def bprop(x1, x2, out, dout):
return (C.zeros_like(x1), C.zeros_like(x2))
return bprop
def infer_shape(self, data1_shape, data2_shape):
# shape = [1, data1_shape[1], data2_shape[2], 16, 16]
# return shape
if self.transpose_a:
k1, m = data1_shape
else:
m, k1 = data1_shape
if self.transpose_b:
n, k2 = data2_shape
else:
k2, n = data2_shape
assert k1 == k2
shape = [m, n]
return shape
def infer_dtype(self, data1_dtype, data2_dtype):
return ms.common.dtype.tensor_type(getattr(ms, "float32"))
class CusMatrixCombine(PrimitiveWithInfer):
"""CusMatMulCube definition"""
@prim_attr_register
def __init__(self):
"""init CusMatMulCube"""
self.init_prim_io_names(inputs=['x'], outputs=['y'])
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
return bprop
def infer_shape(self, data_shape):
a, b, c = data_shape
shape = [a * b, a * c]
return shape
def infer_dtype(self, data_dtype):
return data_dtype
class CusTranspose02314(PrimitiveWithInfer):
"""CusTranspose02314 definition"""
@prim_attr_register
def __init__(self):
"""init CusTranspose02314"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
return bprop
def infer_shape(self, data1_shape):
assert len(data1_shape) == 4
n, c, h, w = data1_shape
c0 = 16
c1 = c // 16
shape = (n * h * w, c1 * c0)
return shape
def infer_dtype(self, data1_dtype):
return data1_dtype
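Worked shape examples for two of the primitives above, with hypothetical sizes: CusTranspose02314 flattens an NCHW input of (n, c, h, w) = (32, 64, 14, 14) into (32 * 14 * 14, (64 // 16) * 16) = (6272, 64); CusMatMulCube(transpose_a=False, transpose_b=True) then maps shapes (m, k) x (n, k) to (m, n), e.g. (6272, 64) x (64, 64) -> (6272, 64), with a float32 output per its infer_dtype.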