diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 56dfbf235e..20712a2564 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ -#define MINDSPORE_CORE_OPERATOR_OPS_H_ +#ifndef MINDSPORE_CORE_BASE_CORE_OPS_H_ +#define MINDSPORE_CORE_BASE_CORE_OPS_H_ #include #include @@ -182,6 +182,7 @@ inline const PrimitivePtr kPrimReverseV2 = std::make_shared("ReverseV inline const PrimitivePtr kPrimReverseSequence = std::make_shared("ReverseSequence"); inline const PrimitivePtr kPrimRank = std::make_shared("Rank"); inline const PrimitivePtr kPrimResizeBilinear = std::make_shared("ResizeBilinear"); +inline const PrimitivePtr kPrimResizeGrad = std::make_shared("ResizeGrad"); // NN inline const PrimitivePtr kPrimAdam = std::make_shared("Adam"); @@ -245,7 +246,6 @@ inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = std::make_shared("DepthwiseConv2dNativeBackpropInput"); inline const PrimitivePtr kPrimDetectionPostProcess = std::make_shared("DetectionPostProcess"); inline const PrimitivePtr kPrimBiasAdd = std::make_shared("BiasAdd"); -inline const PrimitivePtr kPrimBiasGrad = std::make_shared("BiasGrad"); inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); inline const PrimitivePtr kPrimBiasSubGrad = std::make_shared("BiasSubGrad"); inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared("BinaryCrossEntropy"); @@ -390,6 +390,7 @@ inline const PrimitivePtr kPrimRound = std::make_shared("Round"); inline const PrimitivePtr kPrimExp = std::make_shared("Exp"); inline const PrimitivePtr kPrimLog = std::make_shared("Log"); inline const PrimitivePtr kPrimRsqrt = std::make_shared("Rsqrt"); +inline const PrimitivePtr kPrimRsqrtGrad = std::make_shared("RsqrtGrad"); inline const PrimitivePtr kPrimSplitV = std::make_shared("SplitV"); inline const PrimitivePtr kPrimLinSpace = std::make_shared("LinSpace"); inline const PrimitivePtr kPrimNonMaxSuppression = std::make_shared("NonMaxSuppression"); @@ -551,4 +552,4 @@ using DoSignaturePrimitivePtr = std::shared_ptr; } // namespace prim } // namespace mindspore -#endif // MINDSPORE_CORE_OPERATOR_OPS_H_ +#endif // MINDSPORE_CORE_BASE_CORE_OPS_H_ diff --git a/mindspore/core/ops/grad/layer_norm_grad.cc b/mindspore/core/ops/grad/layer_norm_grad.cc new file mode 100644 index 0000000000..6943f79ebd --- /dev/null +++ b/mindspore/core/ops/grad/layer_norm_grad.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/layer_norm_grad.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) { + this->set_begin_norm_axis(begin_norm_axis); + this->set_begin_params_axis(begin_params_axis); +} +void LayerNormGrad::set_begin_norm_axis(const int64_t begin_norm_axis) { + this->AddAttr(kBeginNormAxis, MakeValue(begin_norm_axis)); +} +void LayerNormGrad::set_begin_params_axis(const int64_t begin_params_axis) { + this->AddAttr(kBeginParamsAxis, MakeValue(begin_params_axis)); +} +int64_t LayerNormGrad::get_begin_norm_axis() const { + auto value_ptr = this->GetAttr(kBeginNormAxis); + return GetValue(value_ptr); +} +int64_t LayerNormGrad::get_begin_params_axis() const { + auto value_ptr = this->GetAttr(kBeginParamsAxis); + return GetValue(value_ptr); +} +REGISTER_PRIMITIVE_C(kNameLayerNormGrad, LayerNormGrad); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/layer_norm_grad.h b/mindspore/core/ops/grad/layer_norm_grad.h new file mode 100644 index 0000000000..2400841486 --- /dev/null +++ b/mindspore/core/ops/grad/layer_norm_grad.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ +#define MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ +#include + +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameLayerNormGrad = "LayerNormGrad"; +class LayerNormGrad : public PrimitiveC { + public: + LayerNormGrad() : PrimitiveC(kNameLayerNormGrad) {} + explicit LayerNormGrad(const std::string k_name) : PrimitiveC(k_name) {} + ~LayerNormGrad() = default; + MS_DECLARE_PARENT(LayerNormGrad, PrimitiveC); + void Init(const int64_t begin_norm_axis = 1, const int64_t begin_params_axis = 1); + void set_begin_norm_axis(const int64_t begin_norm_axis); + void set_begin_params_axis(const int64_t begin_params_axis); + int64_t get_begin_norm_axis() const; + int64_t get_begin_params_axis() const; +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GRAD_LAYER_NORM_GRAD_H_ diff --git a/mindspore/core/ops/grad/resize_grad.cc b/mindspore/core/ops/grad/resize_grad.cc new file mode 100644 index 0000000000..39c36631c8 --- /dev/null +++ b/mindspore/core/ops/grad/resize_grad.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/resize_grad.h" +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +void ResizeGrad::Init(const ResizeMethod method, const bool align_corners) { + this->set_method(method); + this->set_align_corners(align_corners); +} + +void ResizeGrad::set_method(const ResizeMethod method) { + auto swi = (int64_t)method; + this->AddAttr(kMethod, MakeValue(swi)); +} + +void ResizeGrad::set_align_corners(const bool align_corners) { this->AddAttr(kAlignCorners, MakeValue(align_corners)); } + +ResizeMethod ResizeGrad::get_method() const { + auto value_ptr = GetAttr(kMethod); + return ResizeMethod(GetValue(value_ptr)); +} + +bool ResizeGrad::get_align_corners() const { + auto value_ptr = GetAttr(kAlignCorners); + return GetValue(value_ptr); +} + +REGISTER_PRIMITIVE_C(kNameResizeGrad, ResizeGrad); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/resize_grad.h b/mindspore/core/ops/grad/resize_grad.h new file mode 100644 index 0000000000..da41b61f0d --- /dev/null +++ b/mindspore/core/ops/grad/resize_grad.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ +#define MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameResizeGrad = "ResizeGrad"; +class ResizeGrad : public PrimitiveC { + public: + ResizeGrad() : PrimitiveC(kNameResizeGrad) {} + ~ResizeGrad() = default; + MS_DECLARE_PARENT(ResizeGrad, PrimitiveC); + void Init(const ResizeMethod method, const bool align_corners); + void set_method(const ResizeMethod method); + void set_align_corners(const bool align_corners); + ResizeMethod get_method() const; + bool get_align_corners() const; +}; + +AbstractBasePtr ResizeGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimResizeGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore/core/ops/grad/rsqrt_grad.cc b/mindspore/core/ops/grad/rsqrt_grad.cc new file mode 100644 index 0000000000..4b2f090f99 --- /dev/null +++ b/mindspore/core/ops/grad/rsqrt_grad.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/rsqrt_grad.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +REGISTER_PRIMITIVE_C(kNameRsqrtGrad, RsqrtGrad); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/rsqrt_grad.h b/mindspore/core/ops/grad/rsqrt_grad.h new file mode 100644 index 0000000000..df6f9795fb --- /dev/null +++ b/mindspore/core/ops/grad/rsqrt_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ +#define MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameRsqrtGrad = "RsqrtGrad"; +class RsqrtGrad : public PrimitiveC { + public: + RsqrtGrad() : PrimitiveC(kNameRsqrtGrad) { InitIOName({"out_backprop", "input"}, {"output"}); } + ~RsqrtGrad() = default; + MS_DECLARE_PARENT(RsqrtGrad, PrimitiveC); + void Init() {} +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GRAD_RSQRT_GRAD_H_ diff --git a/mindspore/core/ops/grad/sqrt_grad.cc b/mindspore/core/ops/grad/sqrt_grad.cc new file mode 100644 index 0000000000..3aefff625e --- /dev/null +++ b/mindspore/core/ops/grad/sqrt_grad.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ops/grad/sqrt_grad.h" +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +REGISTER_PRIMITIVE_C(kNameSqrtGrad, SqrtGrad); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/sqrt_grad.h b/mindspore/core/ops/grad/sqrt_grad.h new file mode 100644 index 0000000000..4ff484fc8b --- /dev/null +++ b/mindspore/core/ops/grad/sqrt_grad.h @@ -0,0 +1,36 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ +#define MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSqrtGrad = "SqrtGrad"; +class SqrtGrad : public PrimitiveC { + public: + SqrtGrad() : PrimitiveC(kNameSqrtGrad) { InitIOName({"out_backprop", "input"}, {"output"}); } + ~SqrtGrad() = default; + MS_DECLARE_PARENT(SqrtGrad, PrimitiveC); + void Init() {} +}; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_GRAD_SQRT_GRAD_H_ diff --git a/mindspore/lite/examples/export_models/models/densenet_train_export.py b/mindspore/lite/examples/export_models/models/densenet_train_export.py index 295b724ad8..20bd76f352 100644 --- a/mindspore/lite/examples/export_models/models/densenet_train_export.py +++ b/mindspore/lite/examples/export_models/models/densenet_train_export.py @@ -15,13 +15,20 @@ """densenet_train_export.""" import sys +import os import numpy as np from train_utils import SaveInOut, TrainWrap -from official.cv.densenet121.src.network.densenet import DenseNet121 import mindspore.common.dtype as mstype from mindspore import context, Tensor, nn from mindspore.train.serialization import export +sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/cv/densenet121/') +#pylint: disable=wrong-import-position +from official.cv.densenet121.src.network.densenet import DenseNet121 + + + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) n = DenseNet121(num_classes=10) diff --git a/mindspore/lite/examples/export_models/models/resnet_train_export.py b/mindspore/lite/examples/export_models/models/resnet_train_export.py index 05b1856379..c0dbe90555 100644 --- a/mindspore/lite/examples/export_models/models/resnet_train_export.py +++ b/mindspore/lite/examples/export_models/models/resnet_train_export.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""mobilenetv2_train_export.""" +"""resnet_train_export""" import sys import numpy as np diff --git a/mindspore/lite/examples/export_models/models/vgg_train_export.py b/mindspore/lite/examples/export_models/models/vgg_train_export.py new file mode 100644 index 0000000000..007825283a --- /dev/null +++ b/mindspore/lite/examples/export_models/models/vgg_train_export.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""vgg_train_export.""" + +import sys +import numpy as np +from train_utils import SaveInOut, TrainWrap +from official.cv.vgg16.src.vgg import vgg16 +import mindspore.common.dtype as mstype +from mindspore import context, Tensor, nn +from mindspore.train.serialization import export + +context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) + +batch = 2 + +n = vgg16(num_classes=10) +loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) +optimizer = nn.Momentum(n.trainable_params(), 0.01, 0.9, use_nesterov=False) +net = TrainWrap(n, loss_fn, optimizer) + +x = Tensor(np.random.randn(batch, 3, 224, 224), mstype.float32) +label = Tensor(np.zeros([batch, 10]).astype(np.float32)) +export(net, x, label, file_name="mindir/vgg_train", file_format='MINDIR') + +if len(sys.argv) > 1: + SaveInOut(sys.argv[1] + "vgg", x, label, n, net) diff --git a/mindspore/lite/examples/export_models/models/xception_train_export.py b/mindspore/lite/examples/export_models/models/xception_train_export.py new file mode 100644 index 0000000000..6b82b3bb05 --- /dev/null +++ b/mindspore/lite/examples/export_models/models/xception_train_export.py @@ -0,0 +1,42 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""inceptionv4_train_export""" + +import sys +import numpy as np +from train_utils import SaveInOut, TrainWrap +from official.cv.xception.src.Xception import Xception +import mindspore.common.dtype as mstype +from mindspore import context, Tensor, nn +from mindspore.train.serialization import export + +context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) + + +n = Xception(num_classes=1000) +n.dropout = nn.Dropout(keep_prob=1.0) + +loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) +optimizer = nn.SGD(n.trainable_params(), learning_rate=0.01, momentum=0.9, dampening=0.0, weight_decay=0.0, + nesterov=True, loss_scale=1.0) +net = TrainWrap(n, loss_fn, optimizer) + +batch = 2 +x = Tensor(np.random.randn(batch, 3, 299, 299), mstype.float32) +label = Tensor(np.zeros([batch, 1000]).astype(np.float32)) +export(net, x, label, file_name="mindir/xception_train", file_format='MINDIR') + +if len(sys.argv) > 1: + SaveInOut(sys.argv[1] + "xception", x, label, n, net) diff --git a/mindspore/lite/examples/export_models/models_train.cfg b/mindspore/lite/examples/export_models/models_train.cfg index 00e4296fd2..24cb443e6e 100644 --- a/mindspore/lite/examples/export_models/models_train.cfg +++ b/mindspore/lite/examples/export_models/models_train.cfg @@ -8,5 +8,7 @@ effnet_tune resnet googlenet nin -#shufflenetv2 -#densenet +densenet +shufflenetv2 +vgg noarm32 +xception diff --git a/mindspore/lite/examples/export_models/prepare.sh b/mindspore/lite/examples/export_models/prepare.sh index d0dd47320b..172691eaa8 100755 --- a/mindspore/lite/examples/export_models/prepare.sh +++ b/mindspore/lite/examples/export_models/prepare.sh @@ -2,10 +2,9 @@ display_usage() { - echo "Usage: prepare.sh [-d mindspore_docker] [-r release.tar.gz] [-i]" + echo "Usage: prepare.sh [-d mindspore_docker] [-i]" echo "Options:" echo " -d docker where mindspore is installed. If no docker is provided script will use local python" - echo " -r release tarball" echo " -i create input and output files" } @@ -20,9 +19,6 @@ checkopts() d) DOCKER=$OPTARG ;; - r) - TARBALL=$OPTARG - ;; i) TRAIN_IO="train_io/" ;; @@ -55,16 +51,6 @@ echo ' ' > ${export_result_file} CLOUD_MODEL_ZOO=../../../../model_zoo/ checkopts "$@" -if [ "$TARBALL" == "" ]; then - file=$(ls ../../../../output/mindspore-lite-*-train-linux-x64.tar.gz) - if [ -f ${file} ]; then - TARBALL=${file} - else - echo "release.tar.gz was not found" - display_usage - exit 1 - fi -fi if [ -z "${DOCKER}" ]; then echo "MindSpore docker was not provided, attempting to run locally" @@ -76,13 +62,14 @@ if [ ! -z "${TRAIN_IO}" ]; then fi while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then - continue + continue fi echo 'exporting' ${model_name} if [ ! -z "${DOCKER}" ]; then - docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir" + docker run -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER --privileged=true ${DOCKER} /bin/bash -c "CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} && chmod 444 mindir/${model_name}_train.mindir" else PYTHONPATH=${CLOUD_MODEL_ZOO} python models/${model_name}_train_export.py ${TRAIN_IO} fi diff --git a/mindspore/lite/examples/train_lenet/Makefile b/mindspore/lite/examples/train_lenet/Makefile index 7e2b69cf4b..b1990f7fe5 100644 --- a/mindspore/lite/examples/train_lenet/Makefile +++ b/mindspore/lite/examples/train_lenet/Makefile @@ -2,8 +2,12 @@ BASE_DIR=$(realpath ../../../../) APP:=bin/net_runner MSLIB:=mindspore-lite LMDLIB:=-lminddata-lite -ljpeg -LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai MSDIR:=$(realpath package-$(TARGET)/lib) +ifneq ("$(wildcard $(MSDIR)/libhiai.so)","") + LHIAILIB:=-lhiai_ir_build -lhiai_ir -lhiai +else + LHIAILIB:= +endif SRC:=src/net_runner.cc OBJ:=$(SRC:.cc=.o) diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index a8690c490a..2eab634590 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -96,7 +96,8 @@ void NetRunner::InitAndFigureInputs() { context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; context.thread_num_ = 2; - loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context); + auto session = mindspore::session::TrainSession::CreateSession(ms_file_, &context); + loop_ = mindspore::session::TrainLoop::CreateTrainLoop(session, &context); session_ = loop_->train_session(); MS_ASSERT(nullptr != session_); diff --git a/mindspore/lite/examples/transfer_learning/.gitignore b/mindspore/lite/examples/transfer_learning/.gitignore new file mode 100644 index 0000000000..e37f8f454f --- /dev/null +++ b/mindspore/lite/examples/transfer_learning/.gitignore @@ -0,0 +1,5 @@ +*.mindir +*.ms +msl +package-* +dataset diff --git a/mindspore/lite/include/train/train_loop.h b/mindspore/lite/include/train/train_loop.h index 1b8dc54f41..d31cbf0b28 100644 --- a/mindspore/lite/include/train/train_loop.h +++ b/mindspore/lite/include/train/train_loop.h @@ -40,11 +40,11 @@ class TrainLoop { public: /// \brief Static method to create a TrainLoop object /// - /// \param[in] filename Filename to read flatbuffer from + /// \param[in] train_session Train session object as return from CreateSession\CreateTransferSession API /// \param[in] context Defines the context of the session to be created /// /// \return Pointer of MindSpore Lite TrainLoop - static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1); + static TrainLoop *CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size = -1); /// \brief Class destructor virtual ~TrainLoop() = default; diff --git a/mindspore/lite/include/train/train_loop_callback.h b/mindspore/lite/include/train/train_loop_callback.h index 4c17ac1d40..a357e08673 100644 --- a/mindspore/lite/include/train/train_loop_callback.h +++ b/mindspore/lite/include/train/train_loop_callback.h @@ -44,11 +44,40 @@ constexpr int RET_EXIT = 2; class TrainLoopCallBack { public: virtual ~TrainLoopCallBack() = default; + + /// \brief This method is called once before the network executing + /// + /// \param[in] cb_data info about current execution virtual void Begin(const TrainLoopCallBackData &cb_data) {} + + /// \brief This method is called once following the network execution + /// + /// \param[in] cb_data info about current execution virtual void End(const TrainLoopCallBackData &cb_data) {} + + /// \brief This method is called at the beginning of each epoch + /// + /// \param[in] cb_data info about current execution virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {} + + /// \brief This method is called after the run of each epoch + /// + /// \param[in] cb_data info about current execution + /// + /// \return indication if to continue in the train loop: + /// RET_CONTINUE -- continue training + /// RET_STOP_TRAINING -- stop training (e.g., due to achieved accuracy) + /// RET_EXIT -- Exit training (due to error of some sort) virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; } + + /// \brief This method is called at the beginning of each step + /// + /// \param[in] cb_data info about current execution virtual void StepBegin(const TrainLoopCallBackData &cb_data) {} + + /// \brief This method is called after each step is ran + /// + /// \param[in] cb_data info about current execution virtual void StepEnd(const TrainLoopCallBackData &cb_data) {} }; diff --git a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java index d76ec714e3..e092a53b96 100644 --- a/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java +++ b/mindspore/lite/java/java/common/src/main/java/com/mindspore/lite/TrainSession.java @@ -142,6 +142,14 @@ public class TrainSession { return this.setLearningRate(this.sessionPtr, learning_rate); } + public boolean setupVirtualBatch(int virtualBatchMultiplier, float learningRate, float momentum) { + return this.setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, learningRate, momentum); + } + + public boolean setupVirtualBatch(int virtualBatchMultiplier) { + return this.setupVirtualBatch(this.sessionPtr, virtualBatchMultiplier, -1.0f, -1.0f); + } + private native long createSession(String modelFilename, long msConfigPtr); private native void bindThread(long sessionPtr, boolean if_bind); @@ -175,4 +183,6 @@ public class TrainSession { private native boolean isEval(long sessionPtr); private native boolean setLearningRate(long sessionPtr, float learning_rate); + + private native boolean setupVirtualBatch(long sessionPtr, int virtualBatchMultiplier, float learningRate, float momentum); } diff --git a/mindspore/lite/java/native/runtime/train_session.cpp b/mindspore/lite/java/native/runtime/train_session.cpp index c37d734c7e..84f13d17fd 100644 --- a/mindspore/lite/java/native/runtime/train_session.cpp +++ b/mindspore/lite/java/native/runtime/train_session.cpp @@ -303,3 +303,18 @@ extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLe auto ret = train_session_ptr->SetLearningRate(learning_rate); return (jboolean)(ret == mindspore::lite::RET_OK); } + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setupVirtualBatch(JNIEnv *env, jobject thiz, + jlong session_ptr, + jint virtualBatchMultiplier, + jfloat learningRate, + jfloat momentum) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->SetupVirtualBatch(virtualBatchMultiplier, learningRate, momentum); + return (jboolean)(ret == mindspore::lite::RET_OK); +} diff --git a/mindspore/lite/micro/cmake/file_list.cmake b/mindspore/lite/micro/cmake/file_list.cmake index 697a174713..22b84b341a 100644 --- a/mindspore/lite/micro/cmake/file_list.cmake +++ b/mindspore/lite/micro/cmake/file_list.cmake @@ -244,6 +244,7 @@ set(LITE_KERNEL_SRC ${LITE_DIR}/nnacl/infer/hashtable_lookup_infer.c ${LITE_DIR}/nnacl/infer/invert_permutation_infer.c ${LITE_DIR}/nnacl/infer/layer_norm_infer.c + ${LITE_DIR}/nnacl/infer/layer_norm_grad_infer.c ${LITE_DIR}/nnacl/infer/lin_space_infer.c ${LITE_DIR}/nnacl/infer/lsh_projection_infer.c ${LITE_DIR}/nnacl/infer/lstm_infer.c diff --git a/mindspore/lite/nnacl/fp32/layer_norm_fp32.c b/mindspore/lite/nnacl/fp32/layer_norm_fp32.c index c05fda2ea1..6781f657df 100644 --- a/mindspore/lite/nnacl/fp32/layer_norm_fp32.c +++ b/mindspore/lite/nnacl/fp32/layer_norm_fp32.c @@ -65,11 +65,10 @@ void LayerNormGammaAndBeta(float *dst, const float *src, const float *gamma_data } int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, - LayerNormParameter *param, size_t task_id) { + LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id) { if (src_data == NULL || dst_data == NULL || gamma_data == NULL || beta_data == NULL) { return NNACL_NULL_PTR; } - int step = UP_DIV(param->norm_outer_size_, param->op_parameter_.thread_num_); int thread_end = MSMIN((task_id + 1) * step, param->norm_outer_size_); for (int i = task_id * step; i < thread_end; i++) { @@ -79,7 +78,10 @@ int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_ float square_mean = 0.0f; LayerNormMeanAndSquare(src_norm, param->norm_inner_size_, &mean, &square_mean); const float deno = 1 / sqrtf(square_mean - mean * mean + param->epsilon_); - + if ((out_mean != NULL) && (out_deno != NULL)) { + out_mean[i] = mean; + out_deno[i] = deno; + } if (param->norm_outer_size_ <= param->params_outer_size_) { for (int x = 0; x < param->norm_inner_size_ / param->params_inner_size_; x++) { const float *src_param = src_norm + x * param->params_inner_size_; diff --git a/mindspore/lite/nnacl/fp32/layer_norm_fp32.h b/mindspore/lite/nnacl/fp32/layer_norm_fp32.h index 44a47cbc16..07c3fc2955 100644 --- a/mindspore/lite/nnacl/fp32/layer_norm_fp32.h +++ b/mindspore/lite/nnacl/fp32/layer_norm_fp32.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ -#define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ +#ifndef MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ +#define MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ #include "nnacl/op_base.h" #include "nnacl/layer_norm_parameter.h" @@ -24,9 +24,9 @@ extern "C" { #endif int LayerNorm(const float *src_data, const float *gamma_data, const float *beta_data, float *dst_data, - LayerNormParameter *param, size_t task_id); + LayerNormParameter *param, float *out_mean, float *out_deno, size_t task_id); #ifdef __cplusplus } #endif -#endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_H_ +#endif // MINDSPORE_LITE_NNACL_FP32_LAYER_NORM_FP32_H_ diff --git a/mindspore/lite/nnacl/fp32_grad/activation_grad.c b/mindspore/lite/nnacl/fp32_grad/activation_grad.c index cb713a550b..ff507f917b 100644 --- a/mindspore/lite/nnacl/fp32_grad/activation_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.c @@ -95,3 +95,18 @@ int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst) { } return NNACL_OK; } + +int EluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { + for (size_t i = 0; i < length; ++i) { + dst[i] = (src1[i] > 0.0f ? src0[i] : alpha * expm1(src1[i]) * src0[i]); + } + return NNACL_OK; +} + +int GeluGrad(float *src0, float *src1, size_t length, float *dst) { + for (size_t i = 0; i < length; ++i) { + dst[i] = src0[i] * ((0.5 * (1.0 + erf(src1[i] / 1.4142135623730951))) + + (src1[i] * exp(-0.5 * src1[i] * src1[i]) / 2.5066282746)); + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32_grad/activation_grad.h b/mindspore/lite/nnacl/fp32_grad/activation_grad.h index 7aa8f755c4..f6a3db303f 100644 --- a/mindspore/lite/nnacl/fp32_grad/activation_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.h @@ -37,6 +37,8 @@ int SigmoidGrad(float *src0, float *src1, size_t length, float *dst); int TanhGrad(float *src0, float *src1, size_t length, float *dst); int HSwishGrad(float *src0, float *src1, size_t length, float *dst); int HSigmoidGrad(float *src0, float *src1, size_t length, float *dst); +int EluGrad(float *src0, float *src1, size_t length, float *dst, float alpha); +int GeluGrad(float *src0, float *src1, size_t length, float *dst); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.c b/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.c index 72784362df..63d6644188 100644 --- a/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.c @@ -16,6 +16,7 @@ #include "nnacl/fp32_grad/arithmetic_grad.h" #include +#include #include "nnacl/fp32_grad/utils.h" #include "nnacl/errorcode.h" @@ -137,3 +138,17 @@ void MinimumByAxes(const float *input0, const float *input1, const float *dy, co } while (NextIndex(num_dims, dy_dims, input_iter)); } } + +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = 0.5f * in2[i] / in1[i]; + } + return NNACL_OK; +} + +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size) { + for (int i = 0; i < element_size; i++) { + out[i] = -0.5f * in2[i] * in1[i] * in1[1] * in1[i]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.h b/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.h index 948fea90f3..f4077634a7 100644 --- a/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.h +++ b/mindspore/lite/nnacl/fp32_grad/arithmetic_grad.h @@ -28,6 +28,9 @@ void MaximumByAxes(const float *input0, const float *input1, const float *dy, co const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims, const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims); +int ElementSqrtGrad(const float *in1, const float *in2, float *out, const int element_size); +int ElementRsqrtGrad(const float *in1, const float *in2, float *out, const int element_size); + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c new file mode 100644 index 0000000000..38357d39c4 --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.c @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/layernorm_grad.h" +#include +#include + +void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db) { + // var is actually 1/sqrf(var)-> var^0.5 + const float *var_sqrt_rev = var; + for (size_t i = 0; i < param_num; ++i) { + float dgamma = 0.0f; + float dbeta = 0.0f; + for (size_t j = i; j < param_size * param_num; j += param_num) { + int norm_shift = (int)(j / block_size); + dgamma += dy[j] * var_sqrt_rev[norm_shift] * (x[j] - mean[norm_shift]); + dbeta += dy[j]; + } + dg[i] = dgamma; + db[i] = dbeta; + } + for (size_t i = 0; i < block_num; ++i) { + float sum1 = 0.0f; + float sum2 = 0.0f; + float sum3 = 0.0f; + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float dxm = x[j] - mean[norm_shift]; + float dyg = dy[j] * gamma[param_shift]; + sum1 += -0.5f * dyg * dxm * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift] * var_sqrt_rev[norm_shift]; + sum3 += -2.0f * dxm; + } + for (size_t j = i * block_size; j < (i + 1) * block_size; ++j) { + int param_shift = j % param_num; + int norm_shift = (int)(j / block_size); + float var_sqrt = var_sqrt_rev[norm_shift]; + float dx1 = dy[j] * gamma[param_shift] * var_sqrt; + float dx2 = sum1 * 2.0f / block_size * (x[j] - mean[norm_shift]); + float dx3 = (-1.0f * var_sqrt * sum2 + (1.0f / block_size) * sum1 * sum3) * (1.0f / block_size); + dx[j] = dx1 + dx2 + dx3; + } + } +} diff --git a/mindspore/lite/nnacl/fp32_grad/layernorm_grad.h b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.h new file mode 100644 index 0000000000..3016fced41 --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/layernorm_grad.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +void LayerNormGrad(const float *x, const float *dy, const float *var, const float *mean, const float *gamma, + int param_num, int param_size, int block_num, int block_size, float *dx, float *dg, float *db); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore/lite/nnacl/fp32_grad/layernormgrad_parameter.h b/mindspore/lite/nnacl/fp32_grad/layernormgrad_parameter.h new file mode 100644 index 0000000000..fcf35e95bf --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/layernormgrad_parameter.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ + +#include "nnacl/op_base.h" + +typedef struct LayerNormGradParameter { + OpParameter op_parameter_; + int begin_norm_axis_; + int begin_params_axis_; +} LayerNormGradParameter; + +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_LAYERNORMGRAD_PARAMETER_H_ diff --git a/mindspore/lite/nnacl/fp32_grad/resize_grad.c b/mindspore/lite/nnacl/fp32_grad/resize_grad.c new file mode 100644 index 0000000000..521e136acd --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/resize_grad.c @@ -0,0 +1,84 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/resize_grad.h" +#include + +void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, + ResizeGradParameter *param) { + bool align_corners = param->align_corners_; + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t in_y = i / param->in_width_; + size_t in_x = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + size_t out_y = MSMIN( + (align_corners) ? (size_t)roundf(in_y * param->height_scale_) : (size_t)floorf(in_y * param->height_scale_), + param->out_height_ - 1); + size_t out_x = MSMIN( + (align_corners) ? (size_t)roundf(in_x * param->width_scale_) : (size_t)floorf(in_x * param->width_scale_), + param->out_width_ - 1); + size_t out_offset = out_y * (param->out_width_ * channel) + (out_x * channel) + c; + size_t in_offset = in_y * (param->in_width_ * channel) + (in_x * channel) + c; + out_addr[out_offset] += in_addr[in_offset]; + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } +} + +void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param) { + size_t in_hw_size = param->in_width_ * param->in_height_; + size_t out_hw_size = param->out_width_ * param->out_height_; + + for (int32_t b = 0; b < batch_size; ++b) { + for (size_t i = 0; i < in_hw_size; ++i) { + size_t h = i / param->in_width_; + size_t w = i % param->in_width_; + for (int32_t c = 0; c < channel; ++c) { + float in_y = (float)h * param->height_scale_; + size_t top_y_index = MSMAX((size_t)(floorf(in_y)), (size_t)(0)); + size_t bottom_y_index = MSMIN((size_t)(ceilf(in_y)), param->out_height_ - 1); + float y_lerp = in_y - floorf(in_y); + float inverse_y_lerp = 1.0 - y_lerp; + + float in_x = (float)w * param->width_scale_; + size_t left_x_index = MSMAX((size_t)(floorf(in_x)), (size_t)(0)); + size_t right_x_index = MSMIN((size_t)(ceilf(in_x)), param->out_width_ - 1); + float x_lerp = in_x - floorf(in_x); + float inverse_x_lerp = 1.0 - x_lerp; + + size_t in_offset = h * (param->in_width_ * channel) + (w * channel) + c; + size_t out_offset_top_y_left_x = top_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_top_y_right_x = top_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + size_t out_offset_bottom_y_left_x = + bottom_y_index * (param->out_width_ * channel) + (left_x_index * channel) + c; + size_t out_offset_bottom_y_right_x = + bottom_y_index * (param->out_width_ * channel) + (right_x_index * channel) + c; + + out_addr[out_offset_top_y_left_x] += in_addr[in_offset] * (float)(inverse_y_lerp * inverse_x_lerp); + out_addr[out_offset_top_y_right_x] += in_addr[in_offset] * (float)(inverse_y_lerp * x_lerp); + out_addr[out_offset_bottom_y_left_x] += in_addr[in_offset] * (float)(y_lerp * inverse_x_lerp); + out_addr[out_offset_bottom_y_right_x] += in_addr[in_offset] * (float)(y_lerp * x_lerp); + } + } + out_addr += out_hw_size * channel; + in_addr += in_hw_size * channel; + } +} diff --git a/mindspore/lite/nnacl/fp32_grad/resize_grad.h b/mindspore/lite/nnacl/fp32_grad/resize_grad.h new file mode 100644 index 0000000000..8d7630810e --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/resize_grad.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ + +#include "nnacl/op_base.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct ResizeGradParameter { + OpParameter op_parameter_; + bool align_corners_; + int method; + size_t in_height_; + size_t in_width_; + size_t out_height_; + size_t out_width_; + float height_scale_; + float width_scale_; +} ResizeGradParameter; + +void ResizeNearestNeighborGrad(float *in_addr, float *out_addr, int batch_size, int channel, + ResizeGradParameter *param); +void ResizeBiLinearGrad(float *in_addr, float *out_addr, int batch_size, int channel, ResizeGradParameter *param); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.c b/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.c new file mode 100644 index 0000000000..bfc368d84e --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.c @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/fp32_grad/unsorted_segment_sum.h" +#include "nnacl/errorcode.h" + +int UnsortedSegmentSum(const float *input, int unit_num, int input_dim1, const int *indices, float *output, + int output_dim0, int output_dim1) { + for (int i = 0; i < unit_num; ++i) { + int j = i / input_dim1; + int k = i % input_dim1; + + int index = indices[j]; + if (index < 0 || index >= output_dim0) { + continue; + } + int output_index = index * output_dim1 + k; + output[output_index] += input[i]; + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.h b/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.h new file mode 100644 index 0000000000..c0b891dc74 --- /dev/null +++ b/mindspore/lite/nnacl/fp32_grad/unsorted_segment_sum.h @@ -0,0 +1,29 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ +#define MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +int UnsortedSegmentSum(const float *input, int unit_num, int input_dim1, const int *indices, float *output, + int output_dim0, int output_dim1); +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore/lite/nnacl/infer/add_sub_grad_infer.c b/mindspore/lite/nnacl/infer/add_sub_grad_infer.c index 7d64d987a6..c61e130147 100644 --- a/mindspore/lite/nnacl/infer/add_sub_grad_infer.c +++ b/mindspore/lite/nnacl/infer/add_sub_grad_infer.c @@ -15,7 +15,7 @@ */ #include "nnacl/infer/add_sub_grad_infer.h" -#include "nnacl/infer/arithmetic_grad_infer.h" +#include "nnacl/arithmetic.h" int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { @@ -32,35 +32,29 @@ int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso TensorC *dx1 = outputs[0]; TensorC *dx2 = outputs[1]; - ArithmeticGradParameter *param = (ArithmeticGradParameter *)parameter; + if (!parameter->infer_flag_) { + return NNACL_INFER_INVALID; + } - int in_shape0[MAX_SHAPE_SIZE]; - size_t in_shape0_size = 0; - ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); - int in_shape1[MAX_SHAPE_SIZE]; - size_t in_shape1_size = 0; - ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_); - int outShape[MAX_SHAPE_SIZE]; - size_t outShape_size = 0; - ShapeSet(outShape, &outShape_size, dy->shape_, dy->shape_size_); + ArithmeticParameter *param = (ArithmeticParameter *)parameter; - param->ndim_ = outShape_size; - param->x1_shape_size_ = param->ndim_; - param->x2_shape_size_ = param->ndim_; - param->dy_shape_size_ = param->ndim_; - int fill_dim_num0 = outShape_size - in_shape0_size; - int fill_dim_num1 = outShape_size - in_shape1_size; + param->ndim_ = dy->shape_size_; + param->in_elements_num0_ = param->ndim_; + param->in_elements_num1_ = param->ndim_; + param->out_elements_num_ = param->ndim_; + int fillDimNum0 = dy->shape_size_ - x1->shape_size_; + int fillDimNum1 = dy->shape_size_ - x2->shape_size_; int j0 = 0; int j1 = 0; - for (unsigned int i = 0; i < outShape_size; i++) { - param->x1_shape_[i] = (i < fill_dim_num0) ? 1 : in_shape0[j0++]; - param->x2_shape_[i] = (i < fill_dim_num1) ? 1 : in_shape1[j1++]; - param->dy_shape_[i] = outShape[i]; + for (unsigned int i = 0; i < dy->shape_size_; i++) { + param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; } SetShapeTensor(dx1, x1); SetShapeTensor(dx2, x2); - dx1->data_type_ = dy->data_type_; - dx2->data_type_ = dy->data_type_; + SetDataTypeFormat(dx1, dy); + SetDataTypeFormat(dx2, dy); return NNACL_OK; } diff --git a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c index ff5d7f334d..11adf06b6b 100644 --- a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c +++ b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.c @@ -15,6 +15,7 @@ */ #include "nnacl/infer/arithmetic_grad_infer.h" +#include "nnacl/arithmetic.h" /* * the Arithmetic Grad op include AddGrad, SubGrad, MulGrad, DivGrad, MaximumGrad, MinimumGrad @@ -38,8 +39,6 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T TensorC *dx1 = outputs[0]; TensorC *dx2 = outputs[1]; - ArithmeticGradParameter *param = (ArithmeticGradParameter *)parameter; - int in_shape0[MAX_SHAPE_SIZE]; size_t in_shape0_size = 0; ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); @@ -50,45 +49,47 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T size_t out_shape_size = 0; ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + if (GetElementNum(dx1) < GetElementNum(dx2)) { param->ndim_ = in_shape1_size; - param->x1_shape_size_ = param->ndim_; - param->x2_shape_size_ = param->ndim_; - param->dy_shape_size_ = param->ndim_; + param->in_elements_num0_ = param->ndim_; + param->in_elements_num1_ = param->ndim_; + param->out_elements_num_ = param->ndim_; int fill_dim_num = in_shape1_size - in_shape0_size; // This will not work for batch! int j = 0; for (unsigned int i = 0; i < in_shape1_size; i++) { if (i < fill_dim_num) { - param->x2_shape_[i] = 1; + param->in_shape1_[i] = 1; } else { - param->x2_shape_[i] = in_shape0[j++]; + param->in_shape1_[i] = in_shape0[j++]; } - param->x1_shape_[i] = in_shape1[i]; - param->dy_shape_[i] = out_shape[i]; + param->in_shape0_[i] = in_shape1[i]; + param->out_shape_[i] = out_shape[i]; } } else if (GetElementNum(dx2) < GetElementNum(dx1)) { param->ndim_ = in_shape0_size; - param->x1_shape_size_ = param->ndim_; - param->x2_shape_size_ = param->ndim_; - param->dy_shape_size_ = param->ndim_; + param->in_elements_num0_ = param->ndim_; + param->in_elements_num1_ = param->ndim_; + param->out_elements_num_ = param->ndim_; param->broadcasting_ = true; int j = 0; int fill_dim_num = in_shape0_size - in_shape1_size; for (unsigned int i = 0; i < in_shape0_size; i++) { if (i < fill_dim_num) { - param->x2_shape_[i] = 1; + param->in_shape1_[i] = 1; } else { - param->x2_shape_[i] = in_shape1[j++]; + param->in_shape1_[i] = in_shape1[j++]; } - param->x1_shape_[i] = in_shape0[i]; - param->dy_shape_[i] = out_shape[i]; + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; } } else { param->broadcasting_ = false; for (unsigned int i = 0; i < in_shape0_size; i++) { - param->x2_shape_[i] = in_shape1[i]; - param->x1_shape_[i] = in_shape0[i]; - param->dy_shape_[i] = out_shape[i]; + param->in_shape1_[i] = in_shape1[i]; + param->in_shape0_[i] = in_shape0[i]; + param->out_shape_[i] = out_shape[i]; } } diff --git a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.h b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.h index 04323116ae..bdb1dbfbf1 100644 --- a/mindspore/lite/nnacl/infer/arithmetic_grad_infer.h +++ b/mindspore/lite/nnacl/infer/arithmetic_grad_infer.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H -#define MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H +#ifndef MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ +#define MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ #include "nnacl/infer/common_infer.h" @@ -22,24 +22,10 @@ extern "C" { #endif -typedef struct ArithmeticGradParameter { - OpParameter op_parameter_; - int type_; - bool broadcasting_; // default false - int ndim_; - // std::vector dy_shape_; - int dy_shape_[MAX_SHAPE_SIZE]; - size_t dy_shape_size_; - int x1_shape_[MAX_SHAPE_SIZE]; - size_t x1_shape_size_; - int x2_shape_[MAX_SHAPE_SIZE]; - size_t x2_shape_size_; -} ArithmeticGradParameter; - int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter); #ifdef __cplusplus } #endif -#endif // MINDSPORE_LITE_NNACL_ARITHMETIC_GRAD_INFER_H +#endif // MINDSPORE_LITE_NNACL_INFER_ARITHMETIC_GRAD_INFER_H_ diff --git a/mindspore/lite/nnacl/infer/flatten_grad_infer.c b/mindspore/lite/nnacl/infer/flatten_grad_infer.c index 117fdea4ae..ddbd1f5858 100644 --- a/mindspore/lite/nnacl/infer/flatten_grad_infer.c +++ b/mindspore/lite/nnacl/infer/flatten_grad_infer.c @@ -19,7 +19,7 @@ int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { #ifdef Debug - int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1); + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 2, 1); if (check_ret != NNACL_OK) { return check_ret; } @@ -33,13 +33,7 @@ int FlattenGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens return NNACL_INFER_INVALID; } - int output_shape[2]; - size_t output_shape_size = 2; - output_shape[0] = input->shape_[0]; - output_shape[1] = 1; - for (size_t i = 1; i < input->shape_size_; i++) { - output_shape[1] *= input->shape_[i]; - } - SetShapeArray(output, output_shape, output_shape_size); + int output_shape_size = inputs[1]->shape_[0]; + SetShapeArray(output, (int *)(inputs[1]->data_), output_shape_size); return NNACL_OK; } diff --git a/mindspore/lite/nnacl/infer/layer_norm_grad_infer.c b/mindspore/lite/nnacl/infer/layer_norm_grad_infer.c new file mode 100644 index 0000000000..7e0524ce97 --- /dev/null +++ b/mindspore/lite/nnacl/infer/layer_norm_grad_infer.c @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "nnacl/infer/layer_norm_grad_infer.h" +#include "nnacl/infer/common_infer.h" +#include "nnacl/fp32_grad/layernormgrad_parameter.h" + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 3); + if (check_ret != NNACL_OK) { + return check_ret; + } + LayerNormGradParameter *param = (LayerNormGradParameter *)parameter; + const TensorC *input_x = inputs[0]; + TensorC *output_dx = outputs[0]; + TensorC *output_dg = outputs[1]; + TensorC *output_db = outputs[2]; + SetDataTypeFormat(output_dx, input_x); + SetDataTypeFormat(output_dg, input_x); + SetDataTypeFormat(output_db, input_x); + SetShapeTensor(output_dx, input_x); + int begin_params_axis = param->begin_params_axis_; + if (param->begin_params_axis_ < 0) { + begin_params_axis += input_x->shape_size_; + } + int size = 0; + for (int i = begin_params_axis; i < input_x->shape_size_; i++) { + output_dg->shape_[size] = input_x->shape_[i]; + output_db->shape_[size] = input_x->shape_[i]; + size++; + } + output_db->shape_size_ = size; + output_dg->shape_size_ = size; + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/infer/layer_norm_grad_infer.h b/mindspore/lite/nnacl/infer/layer_norm_grad_infer.h new file mode 100644 index 0000000000..0e61a1c86c --- /dev/null +++ b/mindspore/lite/nnacl/infer/layer_norm_grad_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ +#define MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_LITE_NNACL_INFER_LAYER_NORM_GRAD_INFER_H_ diff --git a/mindspore/lite/nnacl/infer/layer_norm_infer.c b/mindspore/lite/nnacl/infer/layer_norm_infer.c index 92f2e15b69..ee6adbdd33 100644 --- a/mindspore/lite/nnacl/infer/layer_norm_infer.c +++ b/mindspore/lite/nnacl/infer/layer_norm_infer.c @@ -19,7 +19,10 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { #ifdef Debug - int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 1, 3, 1); + if ((inputs_size != 1 && inputs_size != 3) || (outputs_size != 1 && outputs_size != 3)) { + return NNACL_INPUT_TENSOR_ERROR; + } + int check_ret = CheckAugmentNull(inputs, inputs_size, outputs, outputs_size, parameter); if (check_ret != NNACL_OK) { return check_ret; } @@ -28,11 +31,27 @@ int LayerNormInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor const TensorC *input = inputs[0]; TensorC *output = outputs[0]; SetDataTypeFormat(output, input); + LayerNormParameter *param = (LayerNormParameter *)parameter; if (!param->op_parameter_.infer_flag_) { return NNACL_INFER_INVALID; } - SetShapeTensor(output, input); + // take care of other outputs + if (outputs_size == 3) { + TensorC *output_mean = outputs[1]; + TensorC *output_var = outputs[2]; + SetDataTypeFormat(output_mean, input); + SetDataTypeFormat(output_var, input); + int size = 0; + for (int i = param->begin_norm_axis_; i < input->shape_size_; i++) { + output_mean->shape_[size] = input->shape_[i]; + output_var->shape_[size] = input->shape_[i]; + size++; + } + output_mean->shape_size_ = size; + output_var->shape_size_ = size; + } + return NNACL_OK; } diff --git a/mindspore/lite/nnacl/infer/maximum_grad_infer.c b/mindspore/lite/nnacl/infer/maximum_grad_infer.c index 594f7137d3..53c6793d5f 100644 --- a/mindspore/lite/nnacl/infer/maximum_grad_infer.c +++ b/mindspore/lite/nnacl/infer/maximum_grad_infer.c @@ -15,6 +15,7 @@ */ #include "nnacl/infer/maximum_grad_infer.h" +#include "nnacl/arithmetic.h" int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { @@ -35,19 +36,20 @@ int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tens return NNACL_INFER_INVALID; } - MaximumGradParameter *param = (MaximumGradParameter *)parameter; + ArithmeticParameter *param = (ArithmeticParameter *)parameter; + param->ndim_ = dy->shape_size_; - param->x1_shape_size_ = param->ndim_; - param->x2_shape_size_ = param->ndim_; - param->dy_shape_size_ = param->ndim_; + param->in_elements_num0_ = param->ndim_; + param->in_elements_num1_ = param->ndim_; + param->out_elements_num_ = param->ndim_; int fillDimNum0 = dy->shape_size_ - x1->shape_size_; int fillDimNum1 = dy->shape_size_ - x2->shape_size_; int j0 = 0; int j1 = 0; for (unsigned int i = 0; i < dy->shape_size_; i++) { - param->x1_shape_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; - param->x2_shape_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; - param->dy_shape_[i] = dy->shape_[i]; + param->in_shape0_[i] = (i < fillDimNum0) ? 1 : x1->shape_[j0++]; + param->in_shape1_[i] = (i < fillDimNum1) ? 1 : x2->shape_[j1++]; + param->out_shape_[i] = dy->shape_[i]; } SetShapeTensor(dx1, x1); diff --git a/mindspore/lite/nnacl/infer/maximum_grad_infer.h b/mindspore/lite/nnacl/infer/maximum_grad_infer.h index e76c5e9350..8fdf21d262 100644 --- a/mindspore/lite/nnacl/infer/maximum_grad_infer.h +++ b/mindspore/lite/nnacl/infer/maximum_grad_infer.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H -#define MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H +#ifndef MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ +#define MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ #include "nnacl/infer/common_infer.h" @@ -22,21 +22,10 @@ extern "C" { #endif -typedef struct MaximumGradParameter { - OpParameter op_parameter_; - int ndim_; - int x1_shape_[MAX_SHAPE_SIZE]; - size_t x1_shape_size_; - int x2_shape_[MAX_SHAPE_SIZE]; - size_t x2_shape_size_; - int dy_shape_[MAX_SHAPE_SIZE]; - size_t dy_shape_size_; -} MaximumGradParameter; - int MaximumGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter); #ifdef __cplusplus } #endif -#endif // MINDSPORE_LITE_NNACL_MAXIMUM_GRAD_INFER_H +#endif // MINDSPORE_LITE_NNACL_INFER_MAXIMUM_GRAD_INFER_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 28c8926d29..5183454ecc 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -201,6 +201,10 @@ union PrimitiveType { LinSpace, UniformReal, AbsGrad, + RsqrtGrad, + SqrtGrad, + LayerNormGrad, + ResizeGrad, } table Abs { @@ -1066,3 +1070,20 @@ table UniformReal { table AbsGrad { } + +table RsqrtGrad { +} + +table SqrtGrad { +} + +table LayerNormGrad { + begin_norm_axis: long; + begin_params_axis: long; +} + +table ResizeGrad { + method: ResizeMethod; + align_corners: bool; +} + diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index bfe9b33940..0ecc68d7e6 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -200,6 +200,10 @@ OP_TYPE(IsFinite) OP_TYPE(LinSpace) OP_TYPE(UniformReal) OP_TYPE(AbsGrad) +OP_TYPE(RsqrtGrad) +OP_TYPE(SqrtGrad) +OP_TYPE(LayerNormGrad) +OP_TYPE(ResizeGrad) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1065,3 +1069,19 @@ OP_SCHEMA_DEF_END(UniformReal) OP_SCHEMA_DEF(AbsGrad) OP_SCHEMA_DEF_END(AbsGrad) + +OP_SCHEMA_DEF(RsqrtGrad) +OP_SCHEMA_DEF_END(RsqrtGrad) + +OP_SCHEMA_DEF(SqrtGrad) +OP_SCHEMA_DEF_END(SqrtGrad) + +OP_SCHEMA_DEF(LayerNormGrad) +OP_ATTR(begin_norm_axis, long) +OP_ATTR(begin_params_axis, long) +OP_SCHEMA_DEF_END(LayerNormGrad) + +OP_SCHEMA_DEF(ResizeGrad) +OP_ATTR_ENUM(method, ResizeMethod) +OP_ATTR(align_corners, bool) +OP_SCHEMA_DEF_END(ResizeGrad) diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h index b66d417f22..291c3df6f0 100644 --- a/mindspore/lite/src/ops/ops_func_declare.h +++ b/mindspore/lite/src/ops/ops_func_declare.h @@ -188,6 +188,7 @@ #include "ops/grad/dropout_grad.h" #include "ops/grad/flatten_grad.h" #include "ops/grad/group_conv2d_grad_input.h" +#include "ops/grad/layer_norm_grad.h" #include "ops/grad/log_grad.h" #include "ops/grad/max_pool_grad.h" #include "ops/grad/maximum_grad.h" @@ -196,8 +197,11 @@ #include "ops/grad/neg_grad.h" #include "ops/grad/pooling_grad.h" #include "ops/grad/power_grad.h" +#include "ops/grad/resize_grad.h" +#include "ops/grad/rsqrt_grad.h" #include "ops/grad/sigmoid_cross_entropy_with_logits_grad.h" #include "ops/grad/smooth_l1_loss_grad.h" +#include "ops/grad/sqrt_grad.h" #include "ops/grad/sub_grad.h" #include "ops/fusion/activation.h" #include "ops/fusion/add_fusion.h" @@ -449,5 +453,9 @@ FUNC_MSOP2SCHEMAOP_DECLARE(IsFinite); FUNC_MSOP2SCHEMAOP_DECLARE(LinSpace); FUNC_MSOP2SCHEMAOP_DECLARE(UniformReal); FUNC_MSOP2SCHEMAOP_DECLARE(AbsGrad); +FUNC_MSOP2SCHEMAOP_DECLARE(RsqrtGrad); +FUNC_MSOP2SCHEMAOP_DECLARE(SqrtGrad); +FUNC_MSOP2SCHEMAOP_DECLARE(LayerNormGrad); +FUNC_MSOP2SCHEMAOP_DECLARE(ResizeGrad); #endif #endif // MINDSPORE_LITE_SRC_OPS_OPS_FUNC_DECLARE_H_ diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 0cbfa7e7db..ee60fbb09e 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -48,6 +48,10 @@ schema::PrimitiveT *AbsPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *AbsGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *ActivationPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -336,6 +340,10 @@ schema::PrimitiveT *LayerNormFusionPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *LayerNormGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *LeakyReluPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -516,6 +524,10 @@ schema::PrimitiveT *ResizePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *ResizeGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *ReverseV2PrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -540,6 +552,10 @@ schema::PrimitiveT *RsqrtPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *RsqrtGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *ScaleFusionPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -628,6 +644,10 @@ schema::PrimitiveT *SqrtPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *SqrtGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *SquarePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -648,6 +668,12 @@ schema::PrimitiveT *StridedSlicePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } + +schema::PrimitiveT *StridedSliceGradPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} + schema::PrimitiveT *SubFusionPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -718,6 +744,7 @@ schema::PrimitiveT *ZerosLikePrimitiveCreator(const AnfNodePtr &node) { } RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); +RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); RegistryMSOps g_activationGradPrimitiveCreatorRegistry("ActivationGrad", ActivationGradPrimitiveCreator); RegistryMSOps g_reluGradPrimitiveCreatorRegistry("ReluGrad", ActivationGradPrimitiveCreator); // ? @@ -741,6 +768,8 @@ RegistryMSOps g_audioSpectrogramPrimitiveCreatorRegistry("AudioSpectrogram", Aud RegistryMSOps g_avgPoolPrimitiveCreatorRegistry("AvgPool", AvgPoolFusionPrimitiveCreator); RegistryMSOps g_avgPoolFusionPrimitiveCreatorRegistry("AvgPoolFusion", AvgPoolFusionPrimitiveCreator); RegistryMSOps g_avgPoolGradPrimitiveCreatorRegistry("AvgPoolGrad", AvgPoolGradPrimitiveCreator); +RegistryMSOps g_avgPoolGradGpuPrimitiveCreatorRegistry("AvgPoolGradGpu", AvgPoolGradPrimitiveCreator); +RegistryMSOps g_avgPoolGradCpuPrimitiveCreatorRegistry("AvgPoolGradCpu", AvgPoolGradPrimitiveCreator); RegistryMSOps g_batchNormPrimitiveCreatorRegistry("BatchNorm", BatchNormPrimitiveCreator); RegistryMSOps g_batchToSpacePrimitiveCreatorRegistry("BatchToSpace", BatchToSpacePrimitiveCreator); RegistryMSOps g_batchToSpaceNDPrimitiveCreatorRegistry("BatchToSpaceND", BatchToSpaceNDPrimitiveCreator); @@ -782,6 +811,7 @@ RegistryMSOps g_dropoutPrimitiveCreatorRegistry("Dropout", DropoutPrimitiveCreat RegistryMSOps g_dropoutGradPrimitiveCreatorRegistry("DropoutGrad", DropoutGradPrimitiveCreator); RegistryMSOps g_eltwisePrimitiveCreatorRegistry("Eltwise", EltwisePrimitiveCreator); RegistryMSOps g_eluPrimitiveCreatorRegistry("Elu", EluPrimitiveCreator); +RegistryMSOps g_eluGradPrimitiveCreatorRegistry("EluGrad", ActivationGradPrimitiveCreator); RegistryMSOps g_equalPrimitiveCreatorRegistry("Equal", EqualPrimitiveCreator); RegistryMSOps g_embeddingLookupFusionPrimitiveCreatorRegistry("EmbeddingLookupFusion", EmbeddingLookupFusionPrimitiveCreator); @@ -800,6 +830,7 @@ RegistryMSOps g_fullConnectionPrimitiveCreatorRegistry("FullConnection", FullCon RegistryMSOps g_fusedBatchNormPrimitiveCreatorRegistry("FusedBatchNorm", FusedBatchNormPrimitiveCreator); RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator); RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); +RegistryMSOps g_geluGradPrimitiveCreatorRegistry("GeluGrad", ActivationGradPrimitiveCreator); RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator); @@ -808,6 +839,7 @@ RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNor RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); RegistryMSOps g_layerNormPrimitiveCreatorRegistry("LayerNorm", LayerNormFusionPrimitiveCreator); RegistryMSOps g_layerNormFusionPrimitiveCreatorRegistry("LayerNormFusion", LayerNormFusionPrimitiveCreator); +RegistryMSOps g_layerNormGradPrimitiveCreatorRegistry("LayerNormGrad", LayerNormGradPrimitiveCreator); RegistryMSOps g_leakyReluPrimitiveCreatorRegistry("LeakyRelu", LeakyReluPrimitiveCreator); RegistryMSOps g_lessPrimitiveCreatorRegistry("Less", LessPrimitiveCreator); RegistryMSOps g_lessEqualPrimitiveCreatorRegistry("LessEqual", LessEqualPrimitiveCreator); @@ -857,12 +889,14 @@ RegistryMSOps g_reducePrimitiveCreatorRegistry("Reduce", ReduceFusionPrimitiveCr RegistryMSOps g_reduceFusionPrimitiveCreatorRegistry("ReduceFusion", ReduceFusionPrimitiveCreator); RegistryMSOps g_reshapePrimitiveCreatorRegistry("Reshape", ReshapePrimitiveCreator); RegistryMSOps g_resizePrimitiveCreatorRegistry("Resize", ResizePrimitiveCreator); +RegistryMSOps g_resizeGradPrimitiveCreatorRegistry("ResizeGrad", ResizeGradPrimitiveCreator); RegistryMSOps g_reverseV2PrimitiveCreatorRegistry("ReverseV2", ReverseV2PrimitiveCreator); RegistryMSOps g_reverseSequencePrimitiveCreatorRegistry("ReverseSequence", ReverseSequencePrimitiveCreator); RegistryMSOps g_rfftPrimitiveCreatorRegistry("Rfft", RfftPrimitiveCreator); RegistryMSOps g_rOIPoolingPrimitiveCreatorRegistry("ROIPooling", ROIPoolingPrimitiveCreator); RegistryMSOps g_roundPrimitiveCreatorRegistry("Round", RoundPrimitiveCreator); RegistryMSOps g_rsqrtPrimitiveCreatorRegistry("Rsqrt", RsqrtPrimitiveCreator); +RegistryMSOps g_rsqrtGradPrimitiveCreatorRegistry("RsqrtGrad", RsqrtGradPrimitiveCreator); RegistryMSOps g_quantDTypeCastPrimitiveCreatorRegistry("QuantDTypeCast", QuantDTypeCastPrimitiveCreator); RegistryMSOps g_scalePrimitiveCreatorRegistry("Scale", ScaleFusionPrimitiveCreator); RegistryMSOps g_scaleFusionPrimitiveCreatorRegistry("ScaleFusion", ScaleFusionPrimitiveCreator); @@ -891,11 +925,13 @@ RegistryMSOps g_sparseSoftmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry( RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); +RegistryMSOps g_sqrtGradPrimitiveCreatorRegistry("SqrtGrad", SqrtGradPrimitiveCreator); RegistryMSOps g_squeezePrimitiveCreatorRegistry("Squeeze", SqueezePrimitiveCreator); RegistryMSOps g_squarePrimitiveCreatorRegistry("Square", SquarePrimitiveCreator); RegistryMSOps g_squaredDifferencePrimitiveCreatorRegistry("SquaredDifference", SquaredDifferencePrimitiveCreator); RegistryMSOps g_stackPrimitiveCreatorRegistry("Stack", StackPrimitiveCreator); RegistryMSOps g_stridedSlicePrimitiveCreatorRegistry("StridedSlice", StridedSlicePrimitiveCreator); +RegistryMSOps g_stridedSliceGradPrimitiveCreatorRegistry("StridedSliceGrad", StridedSliceGradPrimitiveCreator); RegistryMSOps g_subPrimitiveCreatorRegistry("Sub", SubFusionPrimitiveCreator); RegistryMSOps g_subFusionPrimitiveCreatorRegistry("SubFusion", SubFusionPrimitiveCreator); RegistryMSOps g_subGradPrimitiveCreatorRegistry("SubGrad", SubGradPrimitiveCreator); diff --git a/mindspore/lite/src/ops/populate/layer_norm_grad_populate.cc b/mindspore/lite/src/ops/populate/layer_norm_grad_populate.cc new file mode 100644 index 0000000000..5ea99618b8 --- /dev/null +++ b/mindspore/lite/src/ops/populate/layer_norm_grad_populate.cc @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { +OpParameter *PopulateLayerNormGradParameter(const void *prim) { + auto layer_norm_grad_parameter = reinterpret_cast(malloc(sizeof(LayerNormGradParameter))); + if (layer_norm_grad_parameter == nullptr) { + MS_LOG(ERROR) << "malloc LayerNormParameter failed."; + return nullptr; + } + memset(layer_norm_grad_parameter, 0, sizeof(LayerNormGradParameter)); + auto *primitive = static_cast(prim); + layer_norm_grad_parameter->op_parameter_.type_ = primitive->value_type(); + auto param = primitive->value_as_LayerNormGrad(); + layer_norm_grad_parameter->begin_norm_axis_ = param->begin_norm_axis(); + layer_norm_grad_parameter->begin_params_axis_ = param->begin_params_axis(); + return reinterpret_cast(layer_norm_grad_parameter); +} + +Registry g_layerNormGradParameterRegistry(schema::PrimitiveType_LayerNormGrad, PopulateLayerNormGradParameter, + SCHEMA_CUR); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/runtime/infer_manager.cc b/mindspore/lite/src/runtime/infer_manager.cc index d29b326906..dd3cebb0c6 100644 --- a/mindspore/lite/src/runtime/infer_manager.cc +++ b/mindspore/lite/src/runtime/infer_manager.cc @@ -63,6 +63,7 @@ #include "nnacl/infer/group_conv2d_grad_input_infer.h" #include "nnacl/infer/hashtable_lookup_infer.h" #include "nnacl/infer/layer_norm_infer.h" +#include "nnacl/infer/layer_norm_grad_infer.h" #include "nnacl/infer/lsh_projection_infer.h" #include "nnacl/infer/lstm_infer.h" #include "nnacl/infer/matmul_infer.h" @@ -214,9 +215,9 @@ static RegistryInferShape g_Deconv2dInferShape(mindspore::schema::PrimitiveType_ static RegistryInferShape g_SquaredDifferenceInferShape(mindspore::schema::PrimitiveType_SquaredDifference, ArithmeticInferShape); static RegistryInferShape g_AddInferShape(mindspore::schema::PrimitiveType_AddFusion, ArithmeticInferShape); -static RegistryInferShape g_AddSubInferShape(mindspore::schema::PrimitiveType_AddGrad, AddSubGradInferShape); +static RegistryInferShape g_AddSubInferShape(mindspore::schema::PrimitiveType_AddGrad, MaximumGradInferShape); static RegistryInferShape g_SubInferShape(mindspore::schema::PrimitiveType_SubFusion, ArithmeticInferShape); -static RegistryInferShape g_SubGradInferShape(mindspore::schema::PrimitiveType_SubGrad, AddSubGradInferShape); +static RegistryInferShape g_SubGradInferShape(mindspore::schema::PrimitiveType_SubGrad, MaximumGradInferShape); static RegistryInferShape g_DivInferShape(mindspore::schema::PrimitiveType_DivFusion, ArithmeticInferShape); static RegistryInferShape g_DivGradInferShape(mindspore::schema::PrimitiveType_DivGrad, ArithmeticGradInferShape); static RegistryInferShape g_MulInferShape(mindspore::schema::PrimitiveType_MulFusion, ArithmeticInferShape); @@ -275,6 +276,8 @@ static RegistryInferShape g_QuantDtypeCastInferShape(mindspore::schema::Primitiv static RegistryInferShape g_MfccInferShape(mindspore::schema::PrimitiveType_Mfcc, MfccInferShape); static RegistryInferShape g_AssignAddInferShape(mindspore::schema::PrimitiveType_AssignAdd, AssignAddInferShape); static RegistryInferShape g_LayerNormInferShape(mindspore::schema::PrimitiveType_LayerNormFusion, LayerNormInferShape); +static RegistryInferShape g_LayerNormGradInferShape(mindspore::schema::PrimitiveType_LayerNormGrad, + LayerNormGradInferShape); static RegistryInferShape g_UnsortedSegmentSumInferShape(mindspore::schema::PrimitiveType_UnsortedSegmentSum, UnsortedSegmentSumInferShape); static RegistryInferShape g_AddnInferShape(mindspore::schema::PrimitiveType_AddN, AddnInferShape); @@ -316,6 +319,7 @@ static RegistryInferShape g_ReverseSequenceInferShape(mindspore::schema::Primiti CommonInferShape); static RegistryInferShape g_ZerosLikeInferShape(mindspore::schema::PrimitiveType_ZerosLike, CommonInferShape); +static RegistryInferShape g_AbsGradInferShape(mindspore::schema::PrimitiveType_AbsGrad, CommonInferShape); static RegistryInferShape g_AbsInferShape(mindspore::schema::PrimitiveType_Abs, CommonInferShape); static RegistryInferShape g_ActivationGradInferShape(mindspore::schema::PrimitiveType_ActivationGrad, CommonInferShape); static RegistryInferShape g_ActivationInferShape(mindspore::schema::PrimitiveType_Activation, CommonInferShape); @@ -345,8 +349,10 @@ static RegistryInferShape g_PowerGradInferShape(mindspore::schema::PrimitiveType static RegistryInferShape g_PReLUInferShape(mindspore::schema::PrimitiveType_PReLUFusion, CommonInferShape); static RegistryInferShape g_ReverseInferShape(mindspore::schema::PrimitiveType_ReverseV2, CommonInferShape); static RegistryInferShape g_RoundInferShape(mindspore::schema::PrimitiveType_Round, CommonInferShape); +static RegistryInferShape g_RsqrtGradInferShape(mindspore::schema::PrimitiveType_RsqrtGrad, CommonInferShape); static RegistryInferShape g_RsqrtInferShape(mindspore::schema::PrimitiveType_Rsqrt, CommonInferShape); static RegistryInferShape g_ScaleInferShape(mindspore::schema::PrimitiveType_ScaleFusion, CommonInferShape); +static RegistryInferShape g_SqrtGradInferShape(mindspore::schema::PrimitiveType_SqrtGrad, CommonInferShape); static RegistryInferShape g_SqrtInferShape(mindspore::schema::PrimitiveType_Sqrt, CommonInferShape); static RegistryInferShape g_SquareInferShape(mindspore::schema::PrimitiveType_Square, CommonInferShape); @@ -426,7 +432,6 @@ static RegistryInferShape g_StridedSliceGradInferShape(mindspore::schema::Primit static RegistryInferShape g_IsFiniteInferShape(mindspore::schema::PrimitiveType_IsFinite, CommonInferShape); static RegistryInferShape g_LinSpaceInferShape(mindspore::schema::PrimitiveType_LinSpace, LinSpaceInferShape); static RegistryInferShape g_UniformRealInferShape(mindspore::schema::PrimitiveType_UniformReal, UniformRealInferShape); -static RegistryInferShape g_AbsGradInferShape(mindspore::schema::PrimitiveType_AbsGrad, CommonInferShape); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 02e321bea2..5801a58206 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ #include #include "src/lite_kernel.h" @@ -122,4 +122,4 @@ class ArithmeticCPUKernel : public LiteKernel { }; int ArithmeticsRun(void *cdata, int task_id); } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index d57f531699..2da7dbe689 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -79,7 +79,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { // init bias size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float); - bias_data_ = malloc(new_bias_size); if (bias_data_ == nullptr) { bias_data_ = reinterpret_cast(malloc(new_bias_size)); if (bias_data_ == nullptr) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc index 318282e776..08e3f10ef3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc @@ -61,7 +61,7 @@ int LayerNormCPUKernel::ReSize() { } int LayerNormCPUKernel::DoLayerNorm(int thread_id) { - int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, param_, thread_id); + int ret = LayerNorm(src_data_, gamma_data_, beta_data_, dst_data_, param_, mean_data_, var_data_, thread_id); if (ret != RET_OK) { MS_LOG(ERROR) << "DoLayerNorm error error_code[" << ret << "]"; return ret; @@ -80,17 +80,17 @@ int LayerNormRun(void *cdata, int task_id) { } int LayerNormCPUKernel::Run() { + int ret = RET_OK; src_data_ = reinterpret_cast(in_tensors_.at(0)->data_c()); gamma_data_ = reinterpret_cast(in_tensors_.at(1)->data_c()); beta_data_ = reinterpret_cast(in_tensors_.at(2)->data_c()); dst_data_ = reinterpret_cast(out_tensors_.at(0)->data_c()); - - auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "LayerNormRun error error_code[" << ret << "]"; - return ret; + if (out_tensors_.size() >= 3) { + mean_data_ = reinterpret_cast(out_tensors_.at(1)->data_c()); + var_data_ = reinterpret_cast(out_tensors_.at(2)->data_c()); } - return RET_OK; + ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_); + return ret; } REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNormFusion, LiteKernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.h index 4bd9255890..b0c87e510b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ #include #include "src/lite_kernel.h" #include "include/context.h" @@ -43,7 +43,9 @@ class LayerNormCPUKernel : public LiteKernel { float *dst_data_ = nullptr; float *gamma_data_ = nullptr; float *beta_data_ = nullptr; + float *mean_data_ = nullptr; + float *var_data_ = nullptr; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LAYER_NORM_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc index 8c59424db2..155d71f1af 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc @@ -25,6 +25,8 @@ using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; +using mindspore::schema::ActivationType_ELU; +using mindspore::schema::ActivationType_GELU; using mindspore::schema::ActivationType_HSWISH; using mindspore::schema::ActivationType_LEAKY_RELU; using mindspore::schema::ActivationType_RELU; @@ -69,6 +71,10 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start); + } else if (param_act_grad_->type_ == schema::ActivationType_ELU) { + error_code = EluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_); + } else if (param_act_grad_->type_ == schema::ActivationType_GELU) { + error_code = GeluGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else { MS_LOG(ERROR) << "Activation type error"; return RET_ERROR; @@ -99,27 +105,5 @@ int ActivationGradCPUKernel::Run() { return RET_OK; } -kernel::LiteKernel *CpuActivationGradFp32KernelCreator(const std::vector &inputs, - const std::vector &outputs, - OpParameter *opParameter, const lite::InnerContext *ctx, - const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_ActivationGrad); - auto *kernel = new (std::nothrow) ActivationGradCPUKernel(opParameter, inputs, outputs, ctx); - if (kernel == nullptr) { - MS_LOG(ERROR) << "new ActivationGradCPUKernel fail!"; - free(opParameter); - return nullptr; - } - auto ret = kernel->Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); - delete kernel; - return nullptr; - } - return kernel; -} - -REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, CpuActivationGradFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ActivationGrad, LiteKernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc index 2922ff498b..a02a50415c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc @@ -73,13 +73,13 @@ int AdamCPUKernel::Execute(int task_id) { auto beta2 = reinterpret_cast(in_tensors_.at(7)->MutableData())[0]; auto eps = reinterpret_cast(in_tensors_.at(8)->MutableData())[0]; auto gradient = reinterpret_cast(in_tensors_.at(9)->MutableData()); - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; + int start = stride * task_id; + int end = start + count; return DoAdam(m, v, gradient, weight, beta1, beta2, beta1_power, beta2_power, eps, learning_rate, adam_param_->use_nesterov_, start, end); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc index a366a3763c..bb1f938510 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -52,12 +52,12 @@ int ApplyMomentumCPUKernel::Execute(int task_id) { float learning_rate = lr_; auto gradient = reinterpret_cast(in_tensors_.at(3)->MutableData()); float moment = reinterpret_cast(in_tensors_.at(4)->MutableData())[0]; - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + int start = stride * task_id; + int end = start + count; DoApplyMomentum(weight, accumulate, learning_rate, gradient, moment, apply_momentum_param_->use_nesterov_, start, end); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc index 1eea6318bf..19ee0e0f5b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc @@ -28,6 +28,8 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_AbsGrad; using mindspore::schema::PrimitiveType_LogGrad; +using mindspore::schema::PrimitiveType_RsqrtGrad; +using mindspore::schema::PrimitiveType_SqrtGrad; namespace mindspore::kernel { namespace { @@ -47,6 +49,12 @@ int ArithmeticSelfGradCPUKernel::Init() { case PrimitiveType_AbsGrad: self_grad_operation_ = ElementAbsGrad; break; + case PrimitiveType_SqrtGrad: + self_grad_operation_ = ElementSqrtGrad; + break; + case PrimitiveType_RsqrtGrad: + self_grad_operation_ = ElementRsqrtGrad; + break; default: MS_LOG(ERROR) << "Unsupported type: " << type; return RET_ERROR; @@ -58,11 +66,11 @@ int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) { auto dy = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto in_x = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto dx = reinterpret_cast(out_tensors_.at(0)->MutableData()); - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + int start = stride * task_id; (*self_grad_operation_)(dy + start, in_x + start, dx + start, count); return RET_OK; @@ -107,4 +115,6 @@ kernel::LiteKernel *CpuArithmeticSelfGradFp32KernelCreator(const std::vector(in_tensors_.at(0)->MutableData()); auto y = reinterpret_cast(in_tensors_.at(1)->MutableData()); - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; + int start = stride * task_id; memcpy(&(x[start]), &(y[start]), count * sizeof(float)); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc index 5c1dd4495a..9845859746 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc @@ -62,14 +62,13 @@ int DropoutGradCPUKernel::Execute(int task_id) { auto mask_ptr = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); auto length = in_tensors_.at(kInputIndex)->ElementsNum(); - int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - - DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_); - + if (count > 0) { + int start = stride * task_id; + DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), count, scale_); + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/elu_grad_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/elu_grad_fp32.h new file mode 100644 index 0000000000..922d0cda63 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/elu_grad_fp32.h @@ -0,0 +1,45 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ + +#include +#include "src/lite_kernel.h" +#include "nnacl/fp32/elu_fp32.h" + +namespace mindspore { +namespace kernel { + +class EluGradCPUKernel : public LiteKernel { + public: + EluGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) {} + ~EluGradCPUKernel() = default; + + int Init() override; + int ReSize() override; + int Run() override; + int DoExcute(int task_id); + + private: + float alpha_ = 1.0; // currently MS supports only alpha = 1.0 +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ELU_GRAD_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.cc new file mode 100644 index 0000000000..5470378af4 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/layernorm_grad.h" +#include + +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "nnacl/fp32_grad/layernorm_grad.h" +#include "nnacl/fp32_grad/layernormgrad_parameter.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_LayerNormGrad; + +namespace mindspore::kernel { +int LayerNormGradCPUKernel::ReSize() { return RET_OK; } + +int LayerNormGradCPUKernel::Init() { + auto lngrad_param = reinterpret_cast(op_parameter_); + auto *input_x = in_tensors_.at(0); + std::vector x_shape = input_x->shape(); + int begin_norm_axis = lngrad_param->begin_norm_axis_; + if (begin_norm_axis < 0) { + begin_norm_axis += x_shape.size(); + } + auto begin_params_axis = lngrad_param->begin_params_axis_; + if (begin_params_axis < 0) { + begin_params_axis += x_shape.size(); + } + for (size_t i = 0; i < static_cast(begin_norm_axis); i++) { + block_num_ *= x_shape[i]; + } + for (size_t i = static_cast(begin_norm_axis); i < x_shape.size(); i++) { + block_size_ *= x_shape[i]; + } + for (size_t i = 0; i < static_cast(begin_params_axis); i++) { + param_size_ *= x_shape[i]; + } + for (size_t i = begin_params_axis; i < x_shape.size(); i++) { + param_num_ *= x_shape[i]; + } + if (block_num_ <= 0 || block_size_ <= 0) { + MS_LOG(ERROR) << "LayerNormGradCPUKernel input shape error, input shape: " << x_shape; + } + return RET_OK; +} + +int LayerNormGradCPUKernel::Execute(int task_id) { + auto input_x = in_tensors_.at(0); + auto input_dy = in_tensors_.at(1); + auto input_var = in_tensors_.at(2); + auto input_mean = in_tensors_.at(3); + auto input_gamma = in_tensors_.at(4); + auto output_dx = out_tensors_.at(0); + auto output_dg = out_tensors_.at(1); + auto output_db = out_tensors_.at(2); + + float *x = reinterpret_cast(input_x->MutableData()); + float *dy = reinterpret_cast(input_dy->MutableData()); + float *var = reinterpret_cast(input_var->MutableData()); + float *mean = reinterpret_cast(input_mean->MutableData()); + float *gamma = reinterpret_cast(input_gamma->MutableData()); + float *dx = reinterpret_cast(output_dx->MutableData()); + float *dg = reinterpret_cast(output_dg->MutableData()); + float *db = reinterpret_cast(output_db->MutableData()); + LayerNormGrad(x, dy, var, mean, gamma, param_num_, param_size_, block_num_, block_size_, dx, dg, db); + return RET_OK; +} + +int LayerNormGradRun(void *cdata, int task_id) { + MS_ASSERT(cdata != nullptr); + auto ln_kernel = reinterpret_cast(cdata); + auto error_code = ln_kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LayerNormGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int LayerNormGradCPUKernel::Run() { + int error_code = ParallelLaunch(this->context_->thread_pool_, LayerNormGradRun, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "LayerNorm function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LayerNormGrad, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.h new file mode 100644 index 0000000000..403b55ce2b --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/layernorm_grad.h @@ -0,0 +1,43 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { + +class LayerNormGradCPUKernel : public LiteKernel { + public: + explicit LayerNormGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) {} + ~LayerNormGradCPUKernel() override {} + int Init() override; + int ReSize() override; + int Run() override; + int Execute(int task_id); + + private: + int block_num_ = 1; + int block_size_ = 1; + int param_num_ = 1; + int param_size_ = 1; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_LAYERNORM_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc index c4a2e68a3c..0c1561297e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc @@ -42,12 +42,12 @@ int NegGradCPUKernel::Init() { return RET_OK; } int NegGradCPUKernel::DoNegGrad(int task_id) { auto dy = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto dx = reinterpret_cast(out_tensors_.at(0)->MutableData()); - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; + int start = stride * task_id; ElementNegative(dy + start, dx + start, count); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index fc18206f3a..e43fabc761 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -66,21 +66,23 @@ int PoolingGradCPUKernel::Execute(int task_id) { PoolingParameter *pool_param = reinterpret_cast(op_parameter_); auto input_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); - int stride = UP_DIV(pool_param->output_batch_, thread_num_); int count = MSMIN(stride, pool_param->output_batch_ - stride * task_id); - int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; - int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; - std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, - 0.f); - if (pool_param->pool_mode_ == PoolMode_MaxPool) { - auto dy_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); - MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, - output_ptr + task_id * stride * in_batch_size, count, pool_param); - } else { - input_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); - AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, count, - pool_param); + + if (count > 0) { + int in_batch_size = pool_param->input_h_ * pool_param->input_w_ * pool_param->input_channel_; + int out_batch_size = pool_param->output_h_ * pool_param->output_w_ * pool_param->input_channel_; + std::fill(output_ptr + task_id * stride * in_batch_size, output_ptr + ((task_id * stride) + count) * in_batch_size, + 0.f); + if (pool_param->pool_mode_ == PoolMode_MaxPool) { + auto dy_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); + MaxPoolingGrad(input_ptr + task_id * stride * in_batch_size, dy_ptr + task_id * stride * out_batch_size, + output_ptr + task_id * stride * in_batch_size, count, pool_param); + } else { + input_ptr = reinterpret_cast(in_tensors_.at(2)->MutableData()); + AvgPoolingGrad(input_ptr + task_id * stride * out_batch_size, output_ptr + task_id * stride * in_batch_size, + count, pool_param); + } } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc index c648381222..bfec240f67 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc @@ -46,19 +46,19 @@ int PowerGradCPUKernel::Execute(int task_id) { auto x_addr = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto dx_addr = reinterpret_cast(out_tensors_.at(0)->MutableData()); - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; + int start = stride * task_id; + int end = start + count; float exp = power_ - 1; Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true); ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count); float scale = scale_ * power_; - for (size_t i = start; i < end; i++) { + for (int i = start; i < end; i++) { dx_addr[i] *= scale; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc new file mode 100644 index 0000000000..34fd46b239 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "src/runtime/kernel/arm/fp32_grad/resize_grad.h" +#include "nnacl/fp32_grad/resize_grad.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_ResizeGrad; + +namespace mindspore::kernel { + +float Scaling(size_t in_size, size_t out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); +} + +int ResizeGradCPUKernel::ReSize() { + auto param = reinterpret_cast(op_parameter_); + if (param == nullptr) { + MS_LOG(ERROR) << "ResizeGradCPUKernel op_parameter_ is nullptr"; + return RET_ERROR; + } + bool align_corners = param->align_corners_; + param->in_height_ = static_cast(in_tensors_.at(0)->Height()); + param->in_width_ = static_cast(in_tensors_.at(0)->Width()); + param->out_height_ = static_cast(out_tensors_.at(0)->Height()); + param->out_width_ = static_cast(out_tensors_.at(0)->Width()); + param->height_scale_ = Scaling(param->out_height_, param->in_height_, align_corners); + param->width_scale_ = Scaling(param->out_width_, param->in_width_, align_corners); + + return RET_OK; +} + +int ResizeGradCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int ResizeGradCPUKernel::Execute(int task_id) { + auto in_addr = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto out_addr = reinterpret_cast(out_tensors_.at(0)->MutableData()); + auto param = reinterpret_cast(op_parameter_); + if (param == nullptr) { + MS_LOG(ERROR) << "ResizeGradCPUKernel op_parameter_ is nullptr"; + return RET_ERROR; + } + auto batch_size = in_tensors_.at(0)->Batch(); + auto channel = in_tensors_.at(0)->Channel(); + + if (param->method == static_cast(schema::ResizeMethod_NEAREST)) { + ResizeNearestNeighborGrad(in_addr, out_addr, batch_size, channel, param); + } else { + ResizeBiLinearGrad(in_addr, out_addr, batch_size, channel, param); + } + return RET_OK; +} + +int ResizeGradRun(void *cdata, int task_id) { + auto resize_grad_kernel = reinterpret_cast(cdata); + auto error_code = resize_grad_kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "resize grad error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int ResizeGradCPUKernel::Run() { + auto out_addr = reinterpret_cast(out_tensors_.at(0)->MutableData()); + size_t elem_number = out_tensors_.at(0)->ElementsNum(); + std::fill(out_addr, out_addr + elem_number, 0.f); + int error_code = ParallelLaunch(this->context_->thread_pool_, ResizeGradRun, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "ResizeGradCPUKernel function error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ResizeGrad, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.h new file mode 100644 index 0000000000..14ec2a3754 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/resize_grad.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class ResizeGradCPUKernel : public LiteKernel { + public: + explicit ResizeGradCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) {} + ~ResizeGradCPUKernel() override = default; + int Init() override; + int ReSize() override; + int Run() override; + int ExecuteInit(int task_id); + int Execute(int task_id); +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_RESIZE_GRAD_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc index 46e0cac4fe..97e8bd2bba 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc @@ -76,13 +76,13 @@ int SgdCPUKernel::Execute(int task_id) { float learning_rate = lr_; auto gradient = reinterpret_cast(in_tensors_.at(1)->MutableData()); float moment = reinterpret_cast(in_tensors_.at(4)->MutableData())[0]; - size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; + int start = stride * task_id; + int end = start + count; DoSgd(weight, accumulate, gradient, learning_rate, sgd_param_->dampening_, moment, sgd_param_->use_nesterov_, start, end); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc index a93efea3b8..165b197be4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc @@ -36,17 +36,17 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) { auto d_loss = reinterpret_cast(in_tensors_.at(2)->MutableData()); auto *out = reinterpret_cast(out_tensors_.at(0)->MutableData()); - const size_t length = in_tensors_.at(0)->ElementsNum(); + int length = in_tensors_.at(0)->ElementsNum(); - size_t stride = UP_DIV(length, thread_count_); - size_t count = MSMIN(stride, length - stride * task_id); + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; + int start = stride * task_id; + int end = start + count; const float beta = smooth_l1_loss_param->beta_; - for (uint64_t i = start; i < end; ++i) { + for (int i = start; i < end; ++i) { float diff = predict[i] - target[i]; if (diff > beta) { out[i] = d_loss[i]; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc new file mode 100644 index 0000000000..8d840966e9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h" +#include +#include +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "nnacl/fp32_grad/unsorted_segment_sum.h" +#include "include/errorcode.h" +#include "src/runtime/runtime_api.h" + +using mindspore::kernel::KERNEL_ARCH::kCPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_UnsortedSegmentSum; + +namespace mindspore::kernel { + +int UnsortedSegmentSumCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + auto input_shape = in_tensors_.at(0)->shape(); + auto segment_ids_shape = in_tensors_.at(1)->shape(); + auto output_shape = out_tensors_.at(0)->shape(); + for (size_t i = 0; i < input_shape.size(); ++i) { + unit_num_ *= input_shape[i]; + if (i >= segment_ids_shape.size()) { + input_dim1_ *= input_shape[i]; + } + } + output_dim0_ = output_shape[0]; + for (size_t j = 1; j < output_shape.size(); j++) { + output_dim1_ *= output_shape[j]; + } + return RET_OK; +} + +int UnsortedSegmentSumCPUKernel::ReSize() { return RET_OK; } + +int UnsortedSegmentSumRun(void *cdata, int task_id) { + MS_ASSERT(cdata != nullptr); + auto kernel = reinterpret_cast(cdata); + auto error_code = kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "UnsortedSegmentSum Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int UnsortedSegmentSumCPUKernel::Run() { + int error_code = ParallelLaunch(this->context_->thread_pool_, UnsortedSegmentSumRun, this, 1); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Strided slice error error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} + +int UnsortedSegmentSumCPUKernel::Execute(int task_id) { + int ret; + auto input_tensor = in_tensors_.at(0); + auto indices_tensor = in_tensors_.at(1); + auto output_tensor = out_tensors_.at(0); + float *input = reinterpret_cast(input_tensor->data_c()); + int *indices = reinterpret_cast(indices_tensor->data_c()); + float *output = reinterpret_cast(output_tensor->MutableData()); + std::fill(output, output + output_tensor->ElementsNum(), 0.f); + ret = UnsortedSegmentSum(input, unit_num_, input_dim1_, indices, output, output_dim0_, output_dim1_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "StridedSliceGrad error error_code[" << ret << "]"; + return RET_ERROR; + } + return RET_OK; +} + +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_UnsortedSegmentSum, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h new file mode 100644 index 0000000000..27fed5d3d5 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ + +#include +#include "src/lite_kernel.h" + +namespace mindspore::kernel { +class UnsortedSegmentSumCPUKernel : public LiteKernel { + public: + UnsortedSegmentSumCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) {} + ~UnsortedSegmentSumCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + int Execute(int task_id); + size_t unit_num_; + size_t input_dim1_; + size_t output_dim0_; + size_t output_dim1_; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_UNSORTED_SEGMENT_SUM_H_ diff --git a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc index 326a5c8e93..ec66bdae4a 100644 --- a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc +++ b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc @@ -18,10 +18,13 @@ #include #include #include "include/errorcode.h" +#include "src/common/log_adapter.h" #include "include/train_session.h" #include "src/common/utils.h" #include "src/train/train_utils.h" +using mindspore::WARNING; + namespace mindspore { namespace lite { diff --git a/mindspore/lite/src/train/optimizer_kernel.h b/mindspore/lite/src/train/optimizer_kernel.h index 10e0842335..21d487554d 100644 --- a/mindspore/lite/src/train/optimizer_kernel.h +++ b/mindspore/lite/src/train/optimizer_kernel.h @@ -69,23 +69,24 @@ class OptimizerKernel : public LiteKernel { std::fill(grad_sum_, grad_sum_ + elem_num, 0); } else { if (grad_sum_ != nullptr) { + OptimizerStep(); context_->allocator->Free(grad_sum_); grad_sum_ = nullptr; } } + weightUpdateMod_ = WeightUpdateMode::VIRTUAL_BATCH; return RET_OK; } int ExecuteVirtualBatch(int task_id) { auto gradient = reinterpret_cast(in_tensors_.at(grad_idx_)->MutableData()); - size_t length = in_tensors_.at(grad_idx_)->ElementsNum(); + int length = in_tensors_.at(grad_idx_)->ElementsNum(); - size_t stride = UP_DIV(length, context_->thread_num_); - size_t count = MSMIN(stride, length - stride * task_id); - size_t start = stride * task_id; - size_t end = start + count; - - for (size_t i = start; i < end; ++i) { + int stride = UP_DIV(length, context_->thread_num_); + int count = MSMIN(stride, length - stride * task_id); + int start = stride * task_id; + int end = start + count; + for (int i = start; i < end; ++i) { grad_sum_[i] += gradient[i]; } valid_grad_sum_ = true; @@ -97,7 +98,10 @@ class OptimizerKernel : public LiteKernel { return RET_OK; } - int Eval() override { return OptimizerStep(); } + int Eval() override { + OptimizerStep(); + return LiteKernel::Eval(); + } protected: float default_lr_ = 0.0f; diff --git a/mindspore/lite/src/train/train_loop.cc b/mindspore/lite/src/train/train_loop.cc index 7d55a0df83..53f4dca27d 100644 --- a/mindspore/lite/src/train/train_loop.cc +++ b/mindspore/lite/src/train/train_loop.cc @@ -22,6 +22,7 @@ #include "include/errorcode.h" #include "include/train_session.h" #include "include/iterator.h" +#include "src/common/log_adapter.h" namespace mindspore { namespace lite { @@ -167,11 +168,9 @@ int TrainLoop::LoadPartialData(std::vector inputs, dataset:: } // namespace lite -session::TrainLoop *session::TrainLoop::CreateTrainLoop(const std::string &model_filename, lite::Context *context, +session::TrainLoop *session::TrainLoop::CreateTrainLoop(session::TrainSession *train_session, lite::Context *context, int batch_size) { - auto train_session = session::TrainSession::CreateSession(model_filename, context); auto loop = new (std::nothrow) lite::TrainLoop(train_session); - return loop; } diff --git a/mindspore/lite/src/train/train_loop.h b/mindspore/lite/src/train/train_loop.h index b0c04a516b..0504ec2198 100644 --- a/mindspore/lite/src/train/train_loop.h +++ b/mindspore/lite/src/train/train_loop.h @@ -20,10 +20,10 @@ #include #include #include +#include "include/errorcode.h" #include "include/train/train_loop.h" #include "include/train/metrics.h" #include "include/train_session.h" -#include "include/errorcode.h" #include "include/datasets.h" #include "include/iterator.h" #include "src/common/log_adapter.h" diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index 6dae3fef33..de36aba915 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -28,8 +28,10 @@ #include "nnacl/fp32_grad/batch_norm.h" #include "nnacl/fp32_grad/dropout_parameter.h" #include "nnacl/fp32_grad/smooth_l1_loss.h" +#include "nnacl/fp32_grad/resize_grad.h" +namespace mindspore { +namespace kernel { -namespace mindspore::kernel { OpParameter *PopulateSmoothL1LossParameter(const void *prim) { SmoothL1LossParameter *p = reinterpret_cast(malloc(sizeof(SmoothL1LossParameter))); if (p == nullptr) { @@ -170,9 +172,20 @@ OpParameter *PopulateMaxPoolGradParameter(const void *prim) { pooling_param->pad_r_ = 0; pooling_param->stride_w_ = static_cast(value->strides()->Get(1)); pooling_param->stride_h_ = static_cast(value->strides()->Get(0)); - pooling_param->round_mode_ = RoundMode_No; pooling_param->pool_mode_ = PoolMode_MaxPool; + switch (value->pad_mode()) { + case schema::PadMode_SAME: + pooling_param->pad_mode_ = Pad_same; + break; + case schema::PadMode_VALID: + pooling_param->pad_mode_ = Pad_valid; + break; + default: + pooling_param->pad_mode_ = Pad_pad; + break; + } + return reinterpret_cast(pooling_param); } @@ -197,8 +210,30 @@ OpParameter *PopulateAvgPoolGradParameter(const void *prim) { pooling_param->stride_w_ = static_cast(value->strides()->Get(1)); pooling_param->stride_h_ = static_cast(value->strides()->Get(0)); + switch (value->pad_mode()) { + case schema::PadMode_SAME: + pooling_param->pad_mode_ = Pad_same; + break; + case schema::PadMode_VALID: + pooling_param->pad_mode_ = Pad_valid; + break; + default: + pooling_param->pad_mode_ = Pad_pad; + break; + } pooling_param->round_mode_ = RoundMode_No; pooling_param->pool_mode_ = PoolMode_AvgPool; + switch (value->pad_mode()) { + case schema::PadMode_SAME: + pooling_param->pad_mode_ = Pad_same; + break; + case schema::PadMode_VALID: + pooling_param->pad_mode_ = Pad_valid; + break; + default: + pooling_param->pad_mode_ = Pad_pad; + break; + } return reinterpret_cast(pooling_param); } @@ -378,6 +413,23 @@ OpParameter *PopulateArithmeticGradParameter(const void *prim) { return reinterpret_cast(arithmetic_param); } +OpParameter *PopulateResizeGradParameter(const void *prim) { + ResizeGradParameter *resize_grad_param = reinterpret_cast(malloc(sizeof(ResizeGradParameter))); + if (resize_grad_param == nullptr) { + MS_LOG(ERROR) << "malloc resize grad parameter failed."; + return nullptr; + } + memset(resize_grad_param, 0, sizeof(ResizeGradParameter)); + auto primitive = static_cast(prim); + resize_grad_param->op_parameter_.type_ = primitive->value_type(); + auto param = primitive->value_as_ResizeGrad(); + + resize_grad_param->method = static_cast(param->method()); + resize_grad_param->align_corners_ = param->align_corners(); + + return reinterpret_cast(resize_grad_param); +} + void PopulateTrainParameters() { lite::Registry ApplyMomentumParameterRegistry(schema::PrimitiveType_ApplyMomentum, PopulateApplyMomentumParameter, lite::SCHEMA_CUR); @@ -437,8 +489,14 @@ void PopulateTrainParameters() { lite::SCHEMA_CUR); lite::Registry StridedSliceGradParameterRegistry(schema::PrimitiveType_StridedSliceGrad, lite::PopulateStridedSliceParameter, lite::SCHEMA_CUR); + lite::Registry SqrtGradParameterRegistry(schema::PrimitiveType_SqrtGrad, lite::DefaultPopulateParameter, + lite::SCHEMA_CUR); + lite::Registry RsqrtGradParameterRegistry(schema::PrimitiveType_RsqrtGrad, lite::DefaultPopulateParameter, + lite::SCHEMA_CUR); + lite::Registry ResizeGradParameterRegistry(schema::PrimitiveType_ResizeGrad, PopulateResizeGradParameter, + lite::SCHEMA_CUR); lite::Registry AbsGradParameterRegistry(schema::PrimitiveType_AbsGrad, lite::DefaultPopulateParameter, lite::SCHEMA_CUR); } - -} // namespace mindspore::kernel +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 642b586efc..1f6108a026 100644 --- a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -272,7 +272,7 @@ void TrainSession::CompileEvalOutputs() { eval_output_node_map_.clear(); eval_output_tensor_map_.clear(); for (auto kernel : this->train_kernels_) { - if (IsLossKernel(kernel)) { + if (IsLossKernel(kernel) && !(IsGradKernel(kernel))) { for (auto in_kernel : kernel->in_kernels()) { if (IsLossKernel(in_kernel) || IsGradKernel(in_kernel)) continue; // insert if not already in diff --git a/mindspore/lite/src/train/transfer_session.cc b/mindspore/lite/src/train/transfer_session.cc index 0e619517fd..900bb1ff3b 100644 --- a/mindspore/lite/src/train/transfer_session.cc +++ b/mindspore/lite/src/train/transfer_session.cc @@ -33,6 +33,7 @@ #include "src/executor.h" #include "src/kernel_registry.h" #include "src/runtime/kernel/arm/fp32_grad/convolution.h" +#include "nnacl/fp32/pack_fp32.h" namespace mindspore { namespace lite { @@ -54,10 +55,20 @@ TransferSession::TransferSession(const char *model_buf_backbone, size_t size_bac std::vector TransferSession::GetInputs() const { return combined_inputs_; } +bool TransferSession::CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask) { + for (std::size_t dim = 0; dim != out->shape().size(); ++dim) { + if (in->shape().at(mask[dim]) != out->shape().at(dim)) { + return false; + } + } + return true; +} + int TransferSession::CompileTransferGraph() { combined_inputs_ = backbone_session_->GetInputs(); auto outputs_backbone = backbone_session_->GetOutputs(); auto inputs_head = lite::TrainSession::GetInputs(); + int ret = RET_OK; for (auto input : inputs_head) { bool match = false; @@ -72,6 +83,11 @@ int TransferSession::CompileTransferGraph() { break; } } + if (match == false && input->shape().size() == 4) { + int nchw2nhwc_mask[4] = {0, 3, 1, 2}; + nchw2nhwc_ = CompileFormatTransform(output, input, nchw2nhwc_mask); + match = nchw2nhwc_; + } if (true == match) { break; } @@ -124,7 +140,14 @@ int TransferSession::RunGraph(const KernelCallBack &before, const KernelCallBack auto output = backbone_head_pair.second; char *input_data = reinterpret_cast(input->MutableData()); char *output_data = reinterpret_cast(output->MutableData()); - std::copy(output_data, output_data + output->Size(), input_data); + if (nchw2nhwc_) { + int plane = input->shape().at(1) * input->shape().at(2); + int batch = input->shape().at(0); + int channel = input->shape().at(3); + PackNCHWToNHWCFp32(output_data, input_data, batch, plane, channel, 0, 1); + } else { + std::copy(output_data, output_data + output->Size(), input_data); + } } ret = lite::TrainSession::RunGraph(before, after); return ret; diff --git a/mindspore/lite/src/train/transfer_session.h b/mindspore/lite/src/train/transfer_session.h index 7a9548ce33..ca1e02b6dc 100644 --- a/mindspore/lite/src/train/transfer_session.h +++ b/mindspore/lite/src/train/transfer_session.h @@ -72,6 +72,8 @@ class TransferSession : public lite::TrainSession { bool is_valid_; private: + bool CompileFormatTransform(tensor::MSTensor *out, tensor::MSTensor *in, int *mask); + bool nchw2nhwc_ = false; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg index 5724ef959d..2ee4e5cbc8 100644 --- a/mindspore/lite/test/models_ms_train.cfg +++ b/mindspore/lite/test/models_ms_train.cfg @@ -1,3 +1,4 @@ +# mini_alexnet_r1.1 mobilenetv1_r1.1 mobilenetv2_r1.1 @@ -5,4 +6,17 @@ lenet_r1.1 effnet_r1.1 effnet_tune_r1.1 googlenet_r1.1 -#LAST +# mini_alexnet +# nin +# lenet +# mobilenetv1 +# mobilenetv2 +# mobilenetv3 +# effnet +# resnet +# effnet_tune +# googlenet +# densenet +# shufflenetv2 +# xception +# LAST diff --git a/mindspore/lite/test/run_net_export.sh b/mindspore/lite/test/run_net_export.sh index f2c79b46af..095ca7ec83 100755 --- a/mindspore/lite/test/run_net_export.sh +++ b/mindspore/lite/test/run_net_export.sh @@ -47,7 +47,8 @@ logs_path=${basepath}/logs_train rm -rf ${logs_path} mkdir -p ${logs_path} -docker_image=mindspore/mindspore-gpu:1.1.0 +docker_image=mindspore_build:210301 +#docker_image=mindspore/mindspore-gpu:1.1.1 # Export models echo "Start Exporting models ..." # Set log files @@ -65,12 +66,15 @@ if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then fi # Export mindspore train models: +fail=0 while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then continue fi echo ${model_name}'_train_export.py' >> "${export_log_file}" + rm -f ${models_path}/${model_name}_train.mindir echo 'exporting' ${model_name} echo 'docker run --user '"$(id -u):$(id -g)"' --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true '${docker_image}' python '${models_path}'/'${model_name}'_train_export.py' >> "${export_log_file}" docker run --user "$(id -u):$(id -g)" --env CLOUD_MODEL_ZOO=${CLOUD_MODEL_ZOO} -w $PWD --runtime=nvidia -v /home/$USER:/home/$USER -v /opt/share:/opt/share --privileged=true "${docker_image}" python ${models_path}'/'${model_name}_train_export.py "${epoch_num}" @@ -78,8 +82,10 @@ while read line; do export_result='export mindspore '${model_name}'_train_export pass';echo ${export_result} >> ${export_result_file} else export_result='export mindspore '${model_name}'_train_export failed';echo ${export_result} >> ${export_result_file} + fail=1 fi done < ${models_mindspore_train_config} Print_Result ${export_result_file} +exit $fail diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh index 641531b604..8141847755 100755 --- a/mindspore/lite/test/run_net_train.sh +++ b/mindspore/lite/test/run_net_train.sh @@ -1,7 +1,7 @@ #!/bin/bash # Run Export on x86 platform and create output test files: -docker_image= +docker_image=mindspore_build:210301 function Run_Export(){ cd $models_path || exit 1 if [[ -z "${CLOUD_MODEL_ZOO}" ]]; then @@ -10,7 +10,8 @@ function Run_Export(){ fi # Export mindspore train models: while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then continue fi @@ -47,10 +48,11 @@ function Run_Converter() { rm -rf ${ms_models_path} mkdir -p ${ms_models_path} - + fail=0 # Convert mindspore train models: while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then continue fi @@ -64,8 +66,10 @@ function Run_Converter() { converter_result='converter mindspore '${model_name}'_train pass';echo ${converter_result} >> ${run_converter_result_file} else converter_result='converter mindspore '${model_name}'_train failed';echo ${converter_result} >> ${run_converter_result_file} + fail=1 fi done < ${models_mindspore_train_config} + return ${fail} } # Run on x86 platform: @@ -73,7 +77,8 @@ function Run_x86() { # Run mindspore converted train models: fail=0 while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then continue fi @@ -81,7 +86,7 @@ function Run_x86() { echo ${model_name}'_train' >> "${run_x86_log_file}" echo 'cd '${x86_path}'/mindspore-lite-'${version}'-train-linux-x64' >> "${run_x86_log_file}" cd ${x86_path}/mindspore-lite-${version}-train-linux-x64 || return 1 - echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib ./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output' >> "${run_x86_log_file}" + echo 'LD_LIBRARY_PATH='${LD_LIBRARY_PATH}':./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib ./benchmark_train/benchmark_train --epochs='${epoch_num}' --modelFile='${ms_models_path}'/'${model_name}'_train.ms --inDataFile='${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin' --expectedDataFile='${train_io_path}'/'${model_name}'_output --exportFile='${ms_models_path}'/'${model_name}'_train_exported.ms' >> "${run_x86_log_file}" echo '-------------------------------------------------------------------------------' >> "${run_x86_log_file}" LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib:./third_party/libjpeg-turbo/lib:./third_party/opencv/lib:./minddata/lib:./minddata/third_party/libjpeg-turbo/lib \ ${run_valgrind}./benchmark_train/benchmark_train \ @@ -159,10 +164,16 @@ function Run_arm() { fail=0 # Run mindir converted train models: while read line; do - model_name=${line} + LFS=" " read -r -a line_array <<< ${line} + model_name=${line_array[0]} if [[ $model_name == \#* ]]; then continue fi + if [[ "${line_array[1]}" == "noarm32" ]] && [[ "$1" == arm32 ]]; then + run_result=$1': '${model_name}'_train irrelevant'; echo ${run_result} >> ${run_benchmark_train_result_file} + continue + fi + # run benchmark_train test without clib data echo ${model_name}'_train' >> "${run_arm_log_file}" @@ -339,7 +350,7 @@ START=$(date +%s.%N) # Run converter echo "start run converter ..." -Run_Converter +Run_Converter & Run_converter_PID=$! sleep 1 diff --git a/mindspore/lite/test/ut/nnacl/infer/maximum_grad_infer_test.cc b/mindspore/lite/test/ut/nnacl/infer/maximum_grad_infer_test.cc index afcce9fcc3..9d8e87fa77 100644 --- a/mindspore/lite/test/ut/nnacl/infer/maximum_grad_infer_test.cc +++ b/mindspore/lite/test/ut/nnacl/infer/maximum_grad_infer_test.cc @@ -15,6 +15,7 @@ */ #include "common/common_test.h" #include "mindspore/lite/nnacl/infer/maximum_grad_infer.h" +#include "mindspore/lite/nnacl/arithmetic.h" namespace mindspore { @@ -44,7 +45,7 @@ TEST_F(MaximumGradInferTest, MaximumGradInferTest0) { std::vector outputs(2, NULL); outputs[0] = new TensorC; outputs[1] = new TensorC; - MaximumGradParameter *parameter = new MaximumGradParameter; + ArithmeticParameter *parameter = new ArithmeticParameter; parameter->op_parameter_.infer_flag_ = true; int ret = MaximumGradInferShape((const TensorC **)inputs.data(), inputs.size(), outputs.data(), outputs.size(), reinterpret_cast(parameter)); @@ -60,18 +61,18 @@ TEST_F(MaximumGradInferTest, MaximumGradInferTest0) { ASSERT_EQ(outputs[1]->data_type_, kNumberTypeInt32); ASSERT_EQ(outputs[1]->format_, Format_NHWC); ASSERT_EQ(parameter->ndim_, 3); - ASSERT_EQ(parameter->dy_shape_size_, 3); - ASSERT_EQ(parameter->dy_shape_[0], 7); - ASSERT_EQ(parameter->dy_shape_[1], 8); - ASSERT_EQ(parameter->dy_shape_[2], 9); - ASSERT_EQ(parameter->x1_shape_size_, 3); - ASSERT_EQ(parameter->x1_shape_[0], 1); - ASSERT_EQ(parameter->x1_shape_[1], 4); - ASSERT_EQ(parameter->x1_shape_[2], 3); - ASSERT_EQ(parameter->x2_shape_size_, 3); - ASSERT_EQ(parameter->x2_shape_[0], 1); - ASSERT_EQ(parameter->x2_shape_[1], 5); - ASSERT_EQ(parameter->x2_shape_[2], 6); + ASSERT_EQ(parameter->out_elements_num_, 3); + ASSERT_EQ(parameter->out_shape_[0], 7); + ASSERT_EQ(parameter->out_shape_[1], 8); + ASSERT_EQ(parameter->out_shape_[2], 9); + ASSERT_EQ(parameter->in_elements_num0_, 3); + ASSERT_EQ(parameter->in_shape0_[0], 1); + ASSERT_EQ(parameter->in_shape0_[1], 4); + ASSERT_EQ(parameter->in_shape0_[2], 3); + ASSERT_EQ(parameter->in_elements_num1_, 3); + ASSERT_EQ(parameter->in_shape1_[0], 1); + ASSERT_EQ(parameter->in_shape1_[1], 5); + ASSERT_EQ(parameter->in_shape1_[2], 6); delete parameter; for (size_t i = 0; i < inputs_size; i++) { delete inputs[i]; diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 0d9c808ec5..2776f554a8 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -98,7 +98,8 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { MS_LOG(ERROR) << "value node is invalid."; return; } - if (value_node->value() != nullptr && opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple)) { + if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTuple) || + opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) { has_make_tuple = true; for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { inputs.emplace_back(make_tuple_node->input(j)); @@ -360,6 +361,9 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrname() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend) { continue; } + if (prim->name() == "make_tuple") { + continue; + } if (prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { continue; @@ -769,7 +773,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; return RET_OK; } else if (value->isa()) { - MS_LOG(INFO) << "value is a monad."; + MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is Monad"; return RET_OK; } else { MS_LOG(ERROR) << "Not support value type , need add support."; diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index 96d425a896..b91e3473b0 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -125,7 +125,8 @@ int NetTrain::ReadInputFile() { return RET_ERROR; } else { if (ms_inputs_.size() > flags_->input_data_list_.size()) { - MS_LOG(ERROR) << "missing input files"; + MS_LOG(ERROR) << "missing input files expecting " << ms_inputs_.size() << ",got " + << flags_->input_data_list_.size(); return RET_ERROR; } for (size_t i = 0; i < ms_inputs_.size(); i++) { @@ -327,8 +328,8 @@ int NetTrain::RunExportedNet() { context->thread_num_ = flags_->num_threads_; session_ = session::TrainSession::CreateSession(flags_->export_file_.c_str(), context.get()); if (session_ == nullptr) { - MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); - std::cout << "CreateSession failed while running ", model_name.c_str(); + MS_LOG(ERROR) << "ExportedFile CreateSession failed while running " << model_name.c_str(); + std::cout << "CreateSession failed while running " << model_name.c_str() << std::endl; return RET_ERROR; } ms_inputs_ = session_->GetInputs(); @@ -344,13 +345,6 @@ int NetTrain::RunExportedNet() { return status; } - status = session_->RunGraph(); - if (status != RET_OK) { - MS_LOG(ERROR) << "Inference error " << status; - std::cerr << "Inference error " << status << std::endl; - return status; - } - if (!flags_->data_file_.empty()) { MS_LOG(INFO) << "Check accuracy for exported model"; std::cout << "Check accuracy for exported model " << std::endl; @@ -391,11 +385,13 @@ int NetTrain::RunNetTrain() { } else { context->device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = NO_BIND; } + + layer_checksum_ = flags_->layer_checksum_; context->thread_num_ = flags_->num_threads_; session_ = session::TrainSession::CreateSession(flags_->model_file_.c_str(), context.get()); if (session_ == nullptr) { - MS_LOG(ERROR) << "CreateSession failed while running ", model_name.c_str(); - std::cout << "CreateSession failed while running ", model_name.c_str(); + MS_LOG(ERROR) << "RunNetTrain CreateSession failed while running " << model_name.c_str(); + std::cout << "RunNetTrain CreateSession failed while running " << model_name.c_str() << std::endl; return RET_ERROR; } @@ -501,7 +497,6 @@ int NetTrain::InitCallbackParameter() { if (op_times_by_name_.find(callParam.node_name) == op_times_by_name_.end()) { op_times_by_name_.insert(std::make_pair(callParam.node_name, std::make_pair(0, 0.0f))); } - op_call_times_total_++; op_begin_ = GetTimeUs(); return true; @@ -526,9 +521,14 @@ int NetTrain::InitCallbackParameter() { op_times_by_type_[call_param.node_type].second += cost; op_times_by_name_[call_param.node_name].first++; op_times_by_name_[call_param.node_name].second += cost; + if (layer_checksum_) { + float *output = reinterpret_cast(after_outputs.at(0)->MutableData()); + float sum = 0; + for (int i = 0; i < after_outputs.at(0)->ElementsNum(); i++) sum += output[i]; + std::cout << call_param.node_type << " shape= " << after_outputs.at(0)->shape() << " sum=" << sum << "\n"; + } return true; }; - return RET_OK; } diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index 4c4990aa31..7abd01b96c 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "tools/common/flag_parser.h" #include "src/common/file_utils.h" #include "src/common/utils.h" @@ -64,6 +65,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); AddFlag(&NetTrainFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); + AddFlag(&NetTrainFlags::layer_checksum_, "layerCheckSum", "layer output checksum print (debug)", false); } ~NetTrainFlags() override = default; @@ -92,6 +94,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { // Resize std::string export_file_ = ""; std::string resize_dims_in_ = ""; + bool layer_checksum_ = false; std::vector> resize_dims_; }; @@ -142,11 +145,16 @@ class MS_API NetTrain { size_t errorCount = 0; float meanError = 0; std::cout << "Data of model output: "; + for (int j = 0; j < std::min(50, size); j++) { + std::cout << static_cast(msTensorData[j]) << " "; + } + std::cout << std::endl; + std::cout << "Data of Ref output : "; + for (int j = 0; j < std::min(50, size); j++) { + std::cout << refOutput[j] << " "; + } for (int j = 0; j < size; j++) { - if (j < 50) { - std::cout << static_cast(msTensorData[j]) << " "; - } - + std::cout << std::endl; if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { std::cerr << "Output tensor has nan or inf data, compare fail" << std::endl; MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail"; @@ -205,6 +213,7 @@ class MS_API NetTrain { mindspore::KernelCallBack before_call_back_; mindspore::KernelCallBack after_call_back_; + bool layer_checksum_ = false; }; int MS_API RunNetTrain(int argc, const char **argv); diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index e2cab08311..33b3313193 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -33,6 +33,7 @@ static const std::vector nhwcOpList = {schema::PrimitiveT schema::PrimitiveType_ApplyMomentum, schema::PrimitiveType_SGD, schema::PrimitiveType_Adam, + schema::PrimitiveType_ResizeGrad, schema::PrimitiveType_AvgPoolFusion, schema::PrimitiveType_MaxPoolFusion, schema::PrimitiveType_Conv2DFusion, @@ -51,8 +52,9 @@ static const std::vector nhwcOpList = {schema::PrimitiveT schema::PrimitiveType_SpaceToBatchND}; static const std::vector nhwcOpAllInputList = { - schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, schema::PrimitiveType_ActivationGrad, - schema::PrimitiveType_Conv2DBackpropFilterFusion, schema::PrimitiveType_BatchNormGrad}; + schema::PrimitiveType_AvgPoolGrad, schema::PrimitiveType_MaxPoolGrad, + schema::PrimitiveType_ActivationGrad, schema::PrimitiveType_Conv2DBackpropFilterFusion, + schema::PrimitiveType_BatchNormGrad, schema::PrimitiveType_ResizeGrad}; // index {} mean all inputs need insert static std::unordered_map> extNhwcInsertIndex = { diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 1c9cdc801a..0f68a1336d 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -128,6 +128,7 @@ int RunConverter(int argc, const char **argv) { oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status); MS_LOG(ERROR) << oss.str(); std::cout << oss.str() << std::endl; + status = RET_ERROR; return status; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index 4f52a04fdc..9dd95fb963 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -172,6 +172,11 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { } } else if (IsContain(GetNhwcAllInputOpList(), opType)) { auto input_size = node->inputIndex.size(); + if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) { + if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) { + input_size = 1; + } + } for (size_t i = 0; i < input_size; i++) { iter = InsertFormatTransNode(graph, iter, kBefore, i, beforeNodeType, &status); if (status != RET_OK) { diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index cbcc11557f..1db8048ec7 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ -#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ +#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ +#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ #include #include @@ -37,6 +37,7 @@ namespace mindspore { namespace opt { inline const PrimitivePtr kPrimReturn = std::make_shared("Return"); inline const PrimitivePtr kPrimMakeTuple = std::make_shared("MakeTuple"); +inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared("make_tuple"); inline const PrimitivePtr kPrimIdentity = std::make_shared("Identity"); std::vector CastToInt(const ValuePtr &value); @@ -146,4 +147,4 @@ ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const const std::string &node_name); } // namespace opt } // namespace mindspore -#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ +#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_GLLO_UTILS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc index ef604e1fa6..9e944c4d81 100644 --- a/mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.cc @@ -131,7 +131,12 @@ constexpr auto kNameHSwishGrad = "HSwishGrad"; constexpr auto kNameReluGrad = "ReluGrad"; constexpr auto kNameReLU6Grad = "ReLU6Grad"; constexpr auto kNameSigmoidGrad = "SigmoidGrad"; +constexpr auto kNameEluGrad = "EluGrad"; +constexpr auto kNameGeluGrad = "GeluGrad"; constexpr auto kNameSlice = "Slice"; +constexpr auto kNameAvgPoolGradGpu = "AvgPoolGradGpu"; +constexpr auto kNameAvgPoolGradCpu = "AvgPoolGradCpu"; + std::map activation_map = {{ops::kNameElu, mindspore::ELU}, {ops::kNameGeLU, mindspore::GELU}, {ops::kNameLeakyRelu, mindspore::LEAKY_RELU}, @@ -145,7 +150,9 @@ std::map activation_map = {{ops::kNameEl {kNameHSwishGrad, mindspore::HSWISH}, {kNameReluGrad, mindspore::RELU}, {kNameReLU6Grad, mindspore::RELU6}, - {kNameSigmoidGrad, mindspore::SIGMOID}}; + {kNameSigmoidGrad, mindspore::SIGMOID}, + {kNameEluGrad, mindspore::ELU}, + {kNameGeluGrad, mindspore::GELU}}; std::map reduce_map = { {ops::kNameReduceAll, mindspore::Reduce_All}, {ops::kNameReduceASum, mindspore::Reduce_ASum}, @@ -351,16 +358,29 @@ int MoveAttrPoolGrad(const CNodePtr &cnode) { MS_LOG(ERROR) << "value node is invalid."; return lite::RET_ERROR; } - auto status = AttrAdjust(src_prim, ops::kKernelSize, {2, 3}); + PrimitivePtr dst_prim; + if (src_prim->name() == kNameAvgPoolGrad || src_prim->name() == kNameAvgPoolGradGpu || + src_prim->name() == kNameAvgPoolGradCpu) { + dst_prim = std::make_shared(); + } else if (src_prim->name() == kNameMaxPoolGrad) { + dst_prim = std::make_shared(); + } else { + MS_LOG(ERROR) << "unsupported pooling type."; + return lite::RET_ERROR; + } + MS_ASSERT(dst_prim != nullptr); + dst_prim->SetAttrs(src_prim->attrs()); + auto status = AttrAdjust(dst_prim, ops::kKernelSize, {2, 3}); if (status != lite::RET_OK) { MS_LOG(ERROR) << "adjust ksize failed."; return status; } - status = AttrAdjust(src_prim, ops::kStrides, {2, 3}); + status = AttrAdjust(dst_prim, ops::kStrides, {2, 3}); if (status != lite::RET_OK) { MS_LOG(ERROR) << "adjust strides failed."; return status; } + value_node->set_value(dst_prim); return lite::RET_OK; } @@ -510,6 +530,8 @@ REGIST_PRIMITIVE_ADJUST(kNameArgMin, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameArgMinWithValue, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameAvgPool, MoveAttrPool) REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGrad, MoveAttrPoolGrad) +REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradGpu, MoveAttrPoolGrad) +REGIST_PRIMITIVE_ADJUST(kNameAvgPoolGradCpu, MoveAttrPoolGrad) REGIST_PRIMITIVE_ADJUST(kNameBatchMatMul, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameBatchNorm, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameConv2DBackpropFilter, MoveAttrMapCommon) @@ -519,10 +541,12 @@ REGIST_PRIMITIVE_ADJUST(kNameDepthWiseConv2D, MoveAttrMapConv2D) REGIST_PRIMITIVE_ADJUST(kNameConv2dTranspose, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameDiv, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameElu, MoveAttrMapActivation) +REGIST_PRIMITIVE_ADJUST(kNameEluGrad, MoveAttrMapActivationGrad) REGIST_PRIMITIVE_ADJUST(kNameExp, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormEx, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameFusedBatchNormGradEx, MoveAttrMapCommon) REGIST_PRIMITIVE_ADJUST(kNameGeLU, MoveAttrMapActivation) +REGIST_PRIMITIVE_ADJUST(kNameGeluGrad, MoveAttrMapActivationGrad) REGIST_PRIMITIVE_ADJUST(kNameHSigmoid, MoveAttrMapActivation) REGIST_PRIMITIVE_ADJUST(kNameHSigmoidGrad, MoveAttrMapActivationGrad) REGIST_PRIMITIVE_ADJUST(kNameHSwish, MoveAttrMapActivation) diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 046998ba63..7cb738c0e5 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -35,6 +35,21 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph return lite::RET_NO_CHANGE; } } + if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) { + if (cnode->size() != InputDoubleNum) { + MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; + remove_cnode_.insert(anf_node); + return lite::RET_NO_CHANGE; + } + } + if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) { + if (cnode->size() != InputDoubleNum) { + MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; + remove_cnode_.insert(anf_node); + return lite::RET_NO_CHANGE; + } + } + bool replace_succ = manager->Replace(anf_node, cnode->input(1)); if (!replace_succ) { MS_LOG(ERROR) << "replace redundant op failed."; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc index 7f7032db17..c83ade69d0 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -70,7 +70,9 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu continue; } if (CheckPrimitiveType(node, prim::kPrimConv2DFusion) || CheckPrimitiveType(node, kPrimConv2DBackpropInputFusion) || - CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion)) { + CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion) || + CheckPrimitiveType(node, prim::kPrimApplyMomentum) || CheckPrimitiveType(node, prim::kPrimSGD) || + CheckPrimitiveType(node, prim::kPrimAdam)) { continue; } auto cnode = node->cast();