!6639 add mobilenetv2_quant and resnet50_quant ST

Merge pull request !6639 from hwjiaorui/master
pull/6639/MERGE
Committed by mindspore-ci-bot via Gitee
commit 3e885c0bc1

@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" create train dataset. """
from functools import partial
import mindspore.dataset as ds
import mindspore.common.dtype as mstype
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
    """
    Create a train dataset.

    Args:
        dataset_path(string): the path of the dataset.
        config(EasyDict): the basic config for training.
        repeat_num(int): the repeat times of the dataset. Default: 1.
        batch_size(int): the batch size of the dataset. Default: 32.

    Returns:
        dataset
    """
    load_func = partial(ds.Cifar10Dataset, dataset_path)
    cifar_ds = load_func(num_parallel_workers=8, shuffle=False)

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    # interpolation default BILINEAR
    resize_op = C.Resize((resize_height, resize_width))
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    changeswap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)
    c_trans = [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    cifar_ds = cifar_ds.map(input_columns="label", operations=type_cast_op)
    cifar_ds = cifar_ds.map(input_columns="image", operations=c_trans)
    # apply batch operations
    cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    cifar_ds = cifar_ds.repeat(repeat_num)
    return cifar_ds
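
For reference, a minimal usage sketch of the helper above (not part of the PR). The CIFAR-10 path and the EasyDict values are placeholders:

# Illustrative only: exercise create_dataset with a throwaway config.
from easydict import EasyDict as ed

cfg = ed({"image_height": 224, "image_width": 224})
train_ds = create_dataset("./cifar-10-batches-bin", cfg, repeat_num=1, batch_size=32)
for image, label in train_ds.create_tuple_iterator():
    print(image.shape, label.shape)   # (32, 3, 224, 224) and (32,)
    break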

@@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    Generate a learning rate array.

    Args:
        global_step(int): the global step to resume from; the returned array starts at this step.
        lr_init(float): initial learning rate.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate.
        warmup_epochs(int): number of warmup epochs.
        total_epochs(int): total epochs of training.
        steps_per_epoch(int): steps of one epoch.

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            lr = lr_end + \
                (lr_max - lr_end) * \
                (1. + math.cos(math.pi * (i - warmup_steps) /
                               (total_steps - warmup_steps))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]
    return learning_rate
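
A quick, hedged sanity check of the schedule this produces; the numbers below are arbitrary and not taken from the tests:

# Illustrative only: 1 warmup epoch plus 4 cosine-decay epochs at 10 steps each,
# resumed from global step 10, i.e. right after warmup.
lrs = get_lr(global_step=10, lr_init=0.0, lr_end=0.0, lr_max=0.3,
             warmup_epochs=1, total_epochs=5, steps_per_epoch=10)
print(lrs.shape)   # (40,): 50 total steps minus the 10 already consumed
print(lrs[0])      # 0.3, the peak rate at the first post-warmup step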

File diff suppressed because it is too large.

@@ -0,0 +1,123 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Mobilenetv2_quant on Cifar10"""
import pytest
import numpy as np
from easydict import EasyDict as ed
from mindspore import context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.model import Model
from mindspore.train.quant import quant
from mindspore.common import set_seed
from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropyWithLabelSmooth
from mobilenetV2 import mobilenetV2

config_ascend_quant = ed({
    "num_classes": 10,
    "image_height": 224,
    "image_width": 224,
    "batch_size": 200,
    "step_threshold": 10,
    "data_load_mode": "mindata",
    "epoch_size": 1,
    "start_epoch": 200,
    "warmup_epochs": 1,
    "lr": 0.3,
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
    "loss_scale": 1024,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 300,
    "save_checkpoint_path": "./checkpoint",
    "quantization_aware": True,
})

dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def train_on_ascend():
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_ascend_quant
    print("training configuration: {}".format(config))
    epoch_size = config.epoch_size

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(
            smooth_factor=config.label_smooth, num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network,
                                          bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    # get learning rate
    lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
                       lr_init=0,
                       lr_end=0,
                       lr_max=config.lr,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=epoch_size + config.start_epoch,
                       steps_per_epoch=step_size))
    # define optimizer
    opt = nn.Momentum(filter(lambda x: x.requires_grad, network.get_parameters()), lr, config.momentum,
                      config.weight_decay)
    # define model
    model = Model(network, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)
    callback = [monitor]
    model.train(epoch_size, dataset, callbacks=callback,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.32
    avg_step_loss = np.mean(np.array(monitor.losses))
    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss


if __name__ == '__main__':
    train_on_ascend()
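
The step that distinguishes this ST case from an ordinary training test is quant.convert_quant_network, which rewrites the fusion cells of the float network with fake-quantization nodes before training. The fragment below repeats that call in isolation as a sketch; the comments assume the commonly documented (weights, activations) ordering of the tuple arguments and add nothing beyond what the test already does:

# Illustrative only: the QAT conversion used above, shown by itself.
from mindspore.train.quant import quant
from mobilenetV2 import mobilenetV2   # defined in the suppressed mobilenetV2.py of this PR

fp_net = mobilenetV2(num_classes=10)
qat_net = quant.convert_quant_network(fp_net,
                                      bn_fold=True,               # fold BatchNorm into the preceding conv
                                      per_channel=[True, False],  # per-channel weights, per-tensor activations
                                      symmetric=[True, False])    # symmetric weights, asymmetric activations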

@@ -0,0 +1,118 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV2 utils"""
import time
import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy.ndarray): learning rate array, indexed by global step for logging.
        step_threshold (int): stop training once this global step is reached. Default: 10.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy(), step_threshold=100)
    """

    def __init__(self, lr_init=None, step_threshold=10):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)
        self.step_threshold = step_threshold

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))
        self.epoch_mseconds = epoch_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
            cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch + 1,
            cb_params.batch_num, step_loss, np.mean(self.losses),
            step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

        if cb_params.cur_step_num == self.step_threshold:
            run_context.request_stop()


class CrossEntropyWithLabelSmooth(_Loss):
    """
    CrossEntropyWithLabelSmooth: cross entropy loss with label smoothing.

    Args:
        smooth_factor (float): smooth factor. Default: 0.
        num_classes (int): number of classes. Default: 1000.

    Returns:
        Tensor, the mean cross entropy loss over the batch.

    Examples:
        >>> CrossEntropyWithLabelSmooth(smooth_factor=0., num_classes=1000)
    """

    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropyWithLabelSmooth, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor /
                                (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)
        self.cast = P.Cast()

    def construct(self, logit, label):
        one_hot_label = self.onehot(self.cast(label, mstype.int32), F.shape(logit)[1],
                                    self.on_value, self.off_value)
        out_loss = self.ce(logit, one_hot_label)
        out_loss = self.mean(out_loss, 0)
        return out_loss
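
With smooth_factor=0.1 and num_classes=10 (the values the test passes), the smoothed target places 0.9 on the true class and 0.1/9, roughly 0.0111, on every other class. A small, purely illustrative sketch of calling the loss on random inputs:

# Illustrative only: evaluate the smoothed cross entropy in PyNative mode.
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)
loss_fn = CrossEntropyWithLabelSmooth(smooth_factor=0.1, num_classes=10)
logits = Tensor(np.random.randn(4, 10).astype(np.float32))
labels = Tensor(np.array([1, 0, 3, 7]), mstype.int32)
print(loss_fn(logits, labels))   # scalar Tensor: mean smoothed cross entropy over the batch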

@@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" create train dataset. """
from functools import partial
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C


def create_dataset(dataset_path, config, repeat_num=1, batch_size=32):
    """
    Create a train dataset.

    Args:
        dataset_path(string): the path of the dataset.
        config(EasyDict): the basic config for training.
        repeat_num(int): the repeat times of the dataset. Default: 1.
        batch_size(int): the batch size of the dataset. Default: 32.

    Returns:
        dataset
    """
    load_func = partial(de.Cifar10Dataset, dataset_path)
    ds = load_func(num_parallel_workers=8, shuffle=False)

    resize_height = config.image_height
    resize_width = config.image_width
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    resize_op = C.Resize((resize_height, resize_width))
    normalize_op = C.Normalize(mean=mean, std=std)
    changeswap_op = C.HWC2CHW()
    c_trans = [resize_op, normalize_op, changeswap_op]
    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(operations=c_trans, input_columns="image",
                num_parallel_workers=8)
    ds = ds.map(operations=type_cast_op,
                input_columns="label", num_parallel_workers=8)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)
    return ds
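
Unlike the MobileNetV2 pipeline earlier in this PR, this version drops the explicit Rescale op and instead folds the 1/255 factor into Normalize by expressing the ImageNet mean/std on the 0..255 scale; the two formulations are equivalent up to the constants chosen. A quick numeric check of that equivalence:

# Illustrative only: normalizing after rescaling equals normalizing with 0..255-scaled stats.
import numpy as np

pixel = np.array([128.0, 64.0, 32.0])
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
a = (pixel / 255.0 - mean) / std               # rescale to 0..1, then normalize
b = (pixel - mean * 255.0) / (std * 255.0)     # normalize directly on the 0..255 scale
print(np.allclose(a, b))                       # True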

@@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Generate a learning rate array.

    Args:
        lr_init(float): initial learning rate.
        lr_end(float): final learning rate.
        lr_max(float): maximum learning rate.
        warmup_epochs(int): number of warmup epochs.
        total_epochs(int): total epochs of training.
        steps_per_epoch(int): steps of one epoch.
        lr_decay_mode(string): learning rate decay mode: 'steps', 'poly', 'cosine'; any other
            value falls back to linear decay.

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs

    if lr_decay_mode == 'steps':
        decay_epoch_index = [0.3 * total_steps,
                             0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
    elif lr_decay_mode == 'poly':
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / \
                float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) /
                        (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
    elif lr_decay_mode == 'cosine':
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * \
                    (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
    else:
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * \
                    (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)

    learning_rate = np.array(lr_each_step).astype(np.float32)
    return learning_rate
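
A hedged sanity check of the 'cosine' branch as the ResNet-50 test configures it; steps_per_epoch is a made-up value here, the test derives it from the dataset:

# Illustrative only: 1 epoch of cosine decay, no warmup, 390 steps per epoch.
lrs = get_lr(lr_init=0.0, lr_end=0.0, lr_max=0.005, warmup_epochs=0,
             total_epochs=1, steps_per_epoch=390, lr_decay_mode='cosine')
print(lrs.shape)          # (390,)
print(lrs[0], lrs[-1])    # ~0.005 at the first step, close to 0 at the last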

File diff suppressed because it is too large.

@@ -0,0 +1,131 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train Resnet50_quant on Cifar10"""
import pytest
import numpy as np
from easydict import EasyDict as ed
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model
from mindspore.train.quant import quant
from mindspore import set_seed
from resnet_quant_manual import resnet50_quant
from dataset import create_dataset
from lr_generator import get_lr
from utils import Monitor, CrossEntropy

config_quant = ed({
    "class_num": 10,
    "batch_size": 128,
    "step_threshold": 20,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 1,
    "pretrained_epoch_size": 90,
    "buffer_size": 1000,
    "image_height": 224,
    "image_width": 224,
    "data_load_mode": "mindata",
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 50,
    "save_checkpoint_path": "./",
    "warmup_epochs": 0,
    "lr_decay_mode": "cosine",
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.005,
})

dataset_path = "/dataset/workspace/mindspore_dataset/cifar-10-batches-bin/"


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def train_on_ascend():
    set_seed(1)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    config = config_quant
    print("training configuration: {}".format(config))
    epoch_size = config.epoch_size

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)
    # define loss
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(
        smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
    # loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
    # define dataset
    dataset = create_dataset(dataset_path=dataset_path,
                             config=config,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()

    # convert fusion network to quantization aware network
    net = quant.convert_quant_network(net,
                                      bn_fold=True,
                                      per_channel=[True, False],
                                      symmetric=[True, False])
    # get learning rate
    lr = Tensor(get_lr(lr_init=config.lr_init,
                       lr_end=0.0,
                       lr_max=config.lr_max,
                       warmup_epochs=config.warmup_epochs,
                       total_epochs=config.epoch_size,
                       steps_per_epoch=step_size,
                       lr_decay_mode='cosine'))
    # define optimizer
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
                   config.weight_decay, config.loss_scale)
    # define model
    # model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
    model = Model(net, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    monitor = Monitor(lr_init=lr.asnumpy(),
                      step_threshold=config.step_threshold)
    callbacks = [monitor]
    model.train(epoch_size, dataset, callbacks=callbacks,
                dataset_sink_mode=False)
    print("============== End Training ==============")

    expect_avg_step_loss = 2.40
    avg_step_loss = np.mean(np.array(monitor.losses))
    print("average step loss:{}".format(avg_step_loss))
    assert avg_step_loss < expect_avg_step_loss


if __name__ == '__main__':
    train_on_ascend()

@@ -0,0 +1,105 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Resnet50 utils"""
import time
import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy.ndarray): learning rate array, indexed by global step for logging.
        step_threshold (int): stop training once this global step is reached. Default: 10.

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05] * 100).asnumpy(), step_threshold=100)
    """

    def __init__(self, lr_init=None, step_threshold=10):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)
        self.step_threshold = step_threshold

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:8.6f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))
        self.epoch_mseconds = epoch_mseconds

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:8.6f}/{:8.6f}], time:[{:5.3f}], lr:[{:5.5f}]".format(
            cb_params.cur_epoch_num, cb_params.epoch_num, cur_step_in_epoch + 1,
            cb_params.batch_num, step_loss, np.mean(self.losses),
            step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

        if cb_params.cur_step_num == self.step_threshold:
            run_context.request_stop()


class CrossEntropy(_Loss):
    """Cross entropy loss with label smoothing, built on SoftmaxCrossEntropyWithLogits."""

    def __init__(self, smooth_factor=0, num_classes=1001):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor /
                                (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(
            logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
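
Beyond logging, Monitor is what keeps both ST cases short: once cur_step_num reaches step_threshold it calls run_context.request_stop(), so only a handful of steps run regardless of epoch_size. A sketch of the wiring; the LR array is a placeholder and the train call mirrors the tests above:

# Illustrative only: plug Monitor into Model.train and read back the per-step losses.
import numpy as np

lr_log = np.array([0.005] * 1000, dtype=np.float32)   # placeholder LR array for logging
monitor = Monitor(lr_init=lr_log, step_threshold=20)
# model.train(1, dataset, callbacks=[monitor], dataset_sink_mode=False)
# avg_step_loss = np.mean(np.array(monitor.losses))   # one entry per executed step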