From 33d77419048efea99398954098f58a36321c8735 Mon Sep 17 00:00:00 2001 From: yoni Date: Sun, 17 Jan 2021 18:17:30 +0200 Subject: [PATCH] tod add train loop --- build.sh | 34 +- mindspore/lite/examples/train_lenet/Makefile | 2 +- .../train_lenet/src/accuracy_monitor.h | 43 +++ .../train_lenet/src/data_callbacks.cc | 103 ++++++ .../examples/train_lenet/src/data_loader.h | 34 ++ .../examples/train_lenet/src/net_runner.cc | 130 ++------ .../examples/train_lenet/src/net_runner.h | 4 +- .../transfer_learning/src/net_runner.cc | 5 +- mindspore/lite/include/train/ckpt_saver.h | 51 +++ .../classification_train_accuracy_monitor.h | 48 +++ mindspore/lite/include/train/loss_monitor.h | 47 +++ mindspore/lite/include/train/lr_scheduler.h | 59 ++++ mindspore/lite/include/train/train_loop.h | 69 ++++ .../lite/include/train/train_loop_callback.h | 57 ++++ mindspore/lite/include/train_session.h | 17 + .../java/com/mindspore/lite/TrainSession.java | 178 ++++++++++ .../java/app/src/main/native/CMakeLists.txt | 46 ++- .../src/main/native/runtime/train_session.cpp | 305 ++++++++++++++++++ mindspore/lite/nnacl/fp32/arithmetic_fp32.c | 4 +- .../lite/nnacl/fp32_grad/activation_grad.c | 44 ++- mindspore/lite/nnacl/fp32_grad/batch_norm.c | 69 ++-- mindspore/lite/nnacl/fp32_grad/batch_norm.h | 10 +- .../fp32_grad/binary_cross_entropy_grad.c | 2 +- mindspore/lite/nnacl/fp32_grad/gemm.c | 7 +- mindspore/lite/nnacl/fp32_grad/gemm.h | 1 + mindspore/lite/nnacl/fp32_grad/pack_ext.c | 112 ++++--- mindspore/lite/schema/ops.fbs | 4 +- mindspore/lite/src/CMakeLists.txt | 5 +- mindspore/lite/src/ops/pad.cc | 38 ++- mindspore/lite/src/ops/primitive_c.cc | 7 +- .../kernel/arm/fp32/convolution_1x1_fp32.cc | 44 +++ .../kernel/arm/fp32/convolution_1x1_fp32.h | 8 +- .../arm/fp32/convolution_delegate_fp32.h | 9 + .../fp32/convolution_depthwise_3x3_fp32.cc | 18 ++ .../arm/fp32/convolution_depthwise_3x3_fp32.h | 8 +- .../arm/fp32/convolution_depthwise_fp32.cc | 18 ++ .../arm/fp32/convolution_depthwise_fp32.h | 8 
+- .../convolution_depthwise_indirect_fp32.cc | 23 ++ .../convolution_depthwise_indirect_fp32.h | 8 +- .../convolution_depthwise_slidewindow_fp32.cc | 19 ++ .../convolution_depthwise_slidewindow_fp32.h | 8 +- .../kernel/arm/fp32/convolution_fp32.cc | 36 +++ .../kernel/arm/fp32/convolution_fp32.h | 3 + .../arm/fp32/convolution_winograd_fp32.cc | 18 +- .../arm/fp32/convolution_winograd_fp32.h | 7 +- .../kernel/arm/fp32_grad/activation_grad.cc | 26 +- .../kernel/arm/fp32_grad/activation_grad.h | 3 +- .../src/runtime/kernel/arm/fp32_grad/adam.cc | 57 ++-- .../src/runtime/kernel/arm/fp32_grad/adam.h | 9 +- .../kernel/arm/fp32_grad/apply_momentum.cc | 35 +- .../kernel/arm/fp32_grad/apply_momentum.h | 11 +- .../arm/fp32_grad/arithmetic_self_grad.cc | 27 +- .../arm/fp32_grad/arithmetic_self_grad.h | 4 +- .../runtime/kernel/arm/fp32_grad/assign.cc | 15 +- .../src/runtime/kernel/arm/fp32_grad/assign.h | 5 +- .../runtime/kernel/arm/fp32_grad/bias_grad.cc | 9 +- .../runtime/kernel/arm/fp32_grad/bn_grad.cc | 42 ++- .../kernel/arm/fp32_grad/convolution.cc | 57 +++- .../kernel/arm/fp32_grad/convolution.h | 7 +- .../arm/fp32_grad/convolution_grad_filter.cc | 78 +++-- .../arm/fp32_grad/convolution_grad_filter.h | 9 +- .../arm/fp32_grad/convolution_grad_input.cc | 41 ++- .../arm/fp32_grad/convolution_grad_input.h | 7 +- .../runtime/kernel/arm/fp32_grad/dropout.cc | 15 +- .../runtime/kernel/arm/fp32_grad/dropout.h | 5 +- .../kernel/arm/fp32_grad/dropout_grad.cc | 10 +- .../kernel/arm/fp32_grad/dropout_grad.h | 3 +- .../runtime/kernel/arm/fp32_grad/neg_grad.cc | 30 +- .../runtime/kernel/arm/fp32_grad/neg_grad.h | 4 +- .../kernel/arm/fp32_grad/pooling_grad.cc | 4 +- .../kernel/arm/fp32_grad/power_grad.cc | 17 +- .../runtime/kernel/arm/fp32_grad/power_grad.h | 3 +- .../src/runtime/kernel/arm/fp32_grad/sgd.cc | 38 ++- .../src/runtime/kernel/arm/fp32_grad/sgd.h | 11 +- .../kernel/arm/fp32_grad/smooth_l1_loss.cc | 13 +- .../kernel/arm/fp32_grad/smooth_l1_loss.h | 5 +- 
.../arm/fp32_grad/smooth_l1_loss_grad.cc | 13 +- .../arm/fp32_grad/smooth_l1_loss_grad.h | 5 +- .../softmax_cross_entropy_with_logits.cc | 4 +- .../kernel/arm/fp32_grad/tuple_getitem.cc | 13 +- .../kernel/arm/fp32_grad/tuple_getitem.h | 3 +- .../classification_train_accuracy_monitor.cc | 98 ++++++ mindspore/lite/src/train/loss_monitor.cc | 66 ++++ mindspore/lite/src/train/lr_scheduler.cc | 75 +++++ mindspore/lite/src/train/optimizer_kernel.h | 35 ++ mindspore/lite/src/train/train_loop.cc | 99 ++++++ mindspore/lite/src/train/train_loop.h | 59 ++++ .../src/train/train_populate_parameter.cc | 16 +- mindspore/lite/src/train/train_session.cc | 33 +- mindspore/lite/src/train/train_session.h | 7 +- mindspore/lite/test/models_ms_train.cfg | 8 +- mindspore/lite/test/run_net_train.sh | 14 +- .../lite/tools/anf_exporter/anf_exporter.cc | 4 +- .../lite/tools/benchmark_train/net_train.cc | 4 +- .../lite/tools/benchmark_train/net_train.h | 7 +- mindspore/lite/tools/common/node_util.cc | 8 +- 96 files changed, 2414 insertions(+), 527 deletions(-) create mode 100644 mindspore/lite/examples/train_lenet/src/accuracy_monitor.h create mode 100644 mindspore/lite/examples/train_lenet/src/data_callbacks.cc create mode 100644 mindspore/lite/examples/train_lenet/src/data_loader.h create mode 100644 mindspore/lite/include/train/ckpt_saver.h create mode 100644 mindspore/lite/include/train/classification_train_accuracy_monitor.h create mode 100644 mindspore/lite/include/train/loss_monitor.h create mode 100644 mindspore/lite/include/train/lr_scheduler.h create mode 100644 mindspore/lite/include/train/train_loop.h create mode 100644 mindspore/lite/include/train/train_loop_callback.h create mode 100644 mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java create mode 100644 mindspore/lite/java/java/app/src/main/native/runtime/train_session.cpp create mode 100644 mindspore/lite/src/train/classification_train_accuracy_monitor.cc create mode 100644 
mindspore/lite/src/train/loss_monitor.cc create mode 100644 mindspore/lite/src/train/lr_scheduler.cc create mode 100644 mindspore/lite/src/train/optimizer_kernel.h create mode 100644 mindspore/lite/src/train/train_loop.cc create mode 100644 mindspore/lite/src/train/train_loop.h diff --git a/build.sh b/build.sh index 172fc758be..168ec20652 100755 --- a/build.sh +++ b/build.sh @@ -395,6 +395,7 @@ if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = " git submodule update --init --recursive akg fi + build_exit() { echo "$@" >&2 @@ -596,33 +597,40 @@ build_lite() build_lite_java_arm64() { # build mindspore-lite arm64 - if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz" ]]; then + JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch64 + if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then + JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch64 + fi + if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then build_lite "arm64" "off" fi # copy arm64 so cd ${BASEPATH}/output/ - rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64 - tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz + rm -rf ${JTARBALL} + tar -zxvf ${JTARBALL}.tar.gz [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/ mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/ - cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ - echo mindspore-lite-${VERSION_STR}-inference-android-aarch64 - [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64 + cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/ + [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} } build_lite_java_arm32() { # build mindspore-lite arm32 - if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! 
-f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz" ]]; then + JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch32 + if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then + JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch32 + fi + if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then build_lite "arm32" "off" fi # copy arm32 so cd ${BASEPATH}/output/ - rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32 - tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz + rm -rf ${JTARBALL} + tar -zxvf ${JTARBALL}.tar.gz [ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/ mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/ - cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ - [ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32 + cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/ + [ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL} } build_jni_arm64() { @@ -635,7 +643,7 @@ build_jni_arm64() { -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ - -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native" + -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native" make -j$THREAD_NUM if [[ $? 
-ne 0 ]]; then echo "---------------- mindspore lite: build jni arm64 failed----------------" @@ -655,7 +663,7 @@ build_jni_arm32() { -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \ -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \ -DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \ - -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native" + -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native" make -j$THREAD_NUM if [[ $? -ne 0 ]]; then echo "---------------- mindspore lite: build jni arm32 failed----------------" diff --git a/mindspore/lite/examples/train_lenet/Makefile b/mindspore/lite/examples/train_lenet/Makefile index 46d750847d..381fd8b5cd 100644 --- a/mindspore/lite/examples/train_lenet/Makefile +++ b/mindspore/lite/examples/train_lenet/Makefile @@ -3,7 +3,7 @@ APP:=bin/net_runner MSLIB:=mindspore-lite MSDIR:=$(realpath package-$(TARGET)/lib) -SRC:=src/net_runner.cc src/dataset.cc +SRC:=src/net_runner.cc src/dataset.cc src/data_callbacks.cc OBJ:=$(SRC:.cc=.o) CFLAGS := -Ofast -std=c++17 \ diff --git a/mindspore/lite/examples/train_lenet/src/accuracy_monitor.h b/mindspore/lite/examples/train_lenet/src/accuracy_monitor.h new file mode 100644 index 0000000000..dc04199b66 --- /dev/null +++ b/mindspore/lite/examples/train_lenet/src/accuracy_monitor.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ +#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ +#include +#include +#include +#include +#include "include/train/train_loop.h" +#include "src/dataset.h" + +using GraphPoint = std::pair; + +class AccuracyMonitor : public mindspore::session::TrainLoopCallBack { + public: + explicit AccuracyMonitor(DataSet *dataset, int check_every_n, int max_steps = -1) + : ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {} + + int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override; + + const std::vector &GetAccuracyPoints() const { return accuracies_; } + + private: + DataSet *ds_; + std::vector accuracies_; + int check_every_n_; + int max_steps_; +}; + +#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_ diff --git a/mindspore/lite/examples/train_lenet/src/data_callbacks.cc b/mindspore/lite/examples/train_lenet/src/data_callbacks.cc new file mode 100644 index 0000000000..1fd8ef9a05 --- /dev/null +++ b/mindspore/lite/examples/train_lenet/src/data_callbacks.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "src/net_runner.h" +#include "include/context.h" +#include "src/utils.h" +#include "src/data_loader.h" +#include "src/accuracy_monitor.h" + +static unsigned int seed = time(NULL); + +std::vector FillInputDataUtil(const mindspore::session::TrainLoopCallBackData &cb_data, + const std::vector &dataset, bool serially) { + static unsigned int idx = 1; + int total_size = dataset.size(); + std::vector labels_vec; + + auto inputs = cb_data.session_->GetInputs(); + char *input_data = reinterpret_cast(inputs.at(0)->MutableData()); + auto labels = reinterpret_cast(inputs.at(1)->MutableData()); + int batch_size = inputs.at(0)->shape()[0]; + int num_of_classes = inputs.at(1)->shape()[1]; + int data_size = inputs.at(0)->Size() / batch_size; + MS_ASSERT(total_size > 0); + MS_ASSERT(input_data != nullptr); + std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f); + for (int i = 0; i < batch_size; i++) { + if (serially) { + idx = ++idx % total_size; + } else { + idx = rand_r(&seed) % total_size; + } + int label = 0; + char *data = nullptr; + std::tie(data, label) = dataset[idx]; + std::copy(data, data + data_size, input_data + i * data_size); + labels[i * num_of_classes + label] = 1.0; // Model expects labels in onehot representation + labels_vec.push_back(label); + } + return labels_vec; +} + +void DataLoader::StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) { + FillInputDataUtil(cb_data, ds_->train_data(), false); +} + +int AccuracyMonitor::EpochEnd(const 
mindspore::session::TrainLoopCallBackData &cb_data) { + if ((cb_data.epoch_ + 1) % check_every_n_ != 0) return mindspore::session::RET_CONTINUE; + + float accuracy = 0.0; + auto inputs = cb_data.session_->GetInputs(); + int batch_size = inputs.at(0)->shape()[0]; + int num_of_classes = ds_->num_of_classes(); + int tests = ds_->test_data().size() / batch_size; + if (max_steps_ != -1 && tests > max_steps_) tests = max_steps_; + cb_data.session_->Eval(); + for (int i = 0; i < tests; i++) { + auto labels = FillInputDataUtil(cb_data, ds_->test_data(), false); + cb_data.session_->RunGraph(); + auto outputs = cb_data.session_->GetPredictions(); + for (auto it = outputs.begin(); it != outputs.end(); ++it) { + if (it->second->ElementsNum() == batch_size * num_of_classes) { + auto scores = reinterpret_cast(it->second->MutableData()); + for (int b = 0; b < batch_size; b++) { + int max_idx = 0; + float max_score = scores[num_of_classes * b]; + for (int c = 1; c < num_of_classes; c++) { + if (scores[num_of_classes * b + c] > max_score) { + max_score = scores[num_of_classes * b + c]; + max_idx = c; + } + } + if (labels[b] == max_idx) accuracy += 1.0; + } + break; + } + } + } + accuracy /= static_cast(batch_size * tests); + accuracies_.push_back(std::make_pair(cb_data.epoch_, accuracy)); + std::cout << cb_data.epoch_ + 1 << ":\tAccuracy is " << accuracy << std::endl; + cb_data.session_->Train(); + return mindspore::session::RET_CONTINUE; +} diff --git a/mindspore/lite/examples/train_lenet/src/data_loader.h b/mindspore/lite/examples/train_lenet/src/data_loader.h new file mode 100644 index 0000000000..884c783121 --- /dev/null +++ b/mindspore/lite/examples/train_lenet/src/data_loader.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ +#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ +#include +#include +#include +#include +#include "include/train/train_loop.h" +#include "src/dataset.h" + +class DataLoader : public mindspore::session::TrainLoopCallBack { + public: + explicit DataLoader(DataSet *dataset) : ds_(dataset) {} + void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override; + + private: + DataSet *ds_; +}; + +#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_ diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index 8c343bcb6a..fa9abd78bc 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -20,10 +20,21 @@ #include #include #include +#include #include "include/context.h" +#include "include/train/loss_monitor.h" +#include "include/train/ckpt_saver.h" +#include "include/train/lr_scheduler.h" +#include "include/train/classification_train_accuracy_monitor.h" #include "src/utils.h" +#include "src/data_loader.h" +#include "src/accuracy_monitor.h" + +using mindspore::session::TrainLoopCallBack; +using mindspore::session::TrainLoopCallBackData; + +static unsigned int seed = time(NULL); -unsigned int NetRunner::seed_ = time(NULL); // Definition of callback function after forwarding operator. 
bool after_callback(const std::vector &after_inputs, const std::vector &after_outputs, @@ -54,15 +65,18 @@ bool after_callback(const std::vector &after_inpu } NetRunner::~NetRunner() { - if (session_ != nullptr) delete session_; + if (loop_ != nullptr) delete loop_; } void NetRunner::InitAndFigureInputs() { mindspore::lite::Context context; context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND; - context.thread_num_ = 1; + context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = false; + context.device_list_[0].device_type_ = mindspore::lite::DT_CPU; + context.thread_num_ = 2; - session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context); + loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context); + session_ = loop_->train_session(); MS_ASSERT(nullptr != session_); auto inputs = session_->GetInputs(); @@ -76,71 +90,10 @@ void NetRunner::InitAndFigureInputs() { } } -mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const { - auto outputs = session_->GetOutputs(); - for (auto it = outputs.begin(); it != outputs.end(); ++it) { - if (it->second->ElementsNum() == size) return it->second; - } - std::cout << "Model does not have an output tensor with size " << size << std::endl; - return nullptr; -} - -std::vector NetRunner::FillInputData(const std::vector &dataset, bool serially) const { - std::vector labels_vec; - static unsigned int idx = 1; - int total_size = dataset.size(); - - auto inputs = session_->GetInputs(); - char *input_data = reinterpret_cast(inputs.at(data_index_)->MutableData()); - auto labels = reinterpret_cast(inputs.at(label_index_)->MutableData()); - MS_ASSERT(total_size > 0); - MS_ASSERT(input_data != nullptr); - std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f); - for (int i = 0; i < batch_size_; i++) { - if (serially) { - idx = ++idx % total_size; - } else { - idx = rand_r(&seed_) % total_size; - } - int label = 
0; - char *data = nullptr; - std::tie(data, label) = dataset[idx]; - std::memcpy(input_data + i * data_size_, data, data_size_); - labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation - labels_vec.push_back(label); - } - - return labels_vec; -} - -float NetRunner::CalculateAccuracy(int max_tests) const { - float accuracy = 0.0; - const std::vector test_set = ds_.test_data(); - int tests = test_set.size() / batch_size_; - if (max_tests != -1 && tests < max_tests) tests = max_tests; - - session_->Eval(); - for (int i = 0; i < tests; i++) { - auto labels = FillInputData(test_set, (max_tests == -1)); - session_->RunGraph(); - auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_); - MS_ASSERT(outputsv != nullptr); - auto scores = reinterpret_cast(outputsv->MutableData()); - for (int b = 0; b < batch_size_; b++) { - int max_idx = 0; - float max_score = scores[num_of_classes_ * b]; - for (int c = 0; c < num_of_classes_; c++) { - if (scores[num_of_classes_ * b + c] > max_score) { - max_score = scores[num_of_classes_ * b + c]; - max_idx = c; - } - } - if (labels[b] == max_idx) accuracy += 1.0; - } - } - session_->Train(); - accuracy /= static_cast(batch_size_ * tests); - return accuracy; +float NetRunner::CalculateAccuracy(int max_tests) { + AccuracyMonitor test_am(&ds_, 1, max_tests); + test_am.EpochEnd(TrainLoopCallBackData(true, 0, session_, loop_)); + return 0.0; } int NetRunner::InitDB() { @@ -155,35 +108,17 @@ int NetRunner::InitDB() { return ret; } -float NetRunner::GetLoss() const { - auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor - MS_ASSERT(outputsv != nullptr); - auto loss = reinterpret_cast(outputsv->MutableData()); - return loss[0]; -} - int NetRunner::TrainLoop() { - session_->Train(); - float min_loss = 1000.; - float max_acc = 0.; - for (int i = 0; i < cycles_; i++) { - FillInputData(ds_.train_data()); - session_->RunGraph(nullptr, verbose_ ? 
after_callback : nullptr); - float loss = GetLoss(); - if (min_loss > loss) min_loss = loss; - - if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) { - auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms"; - session_->SaveToFile(cpkt_fn); - } + struct mindspore::lite::StepLRLambda step_lr_lambda(100, 0.9); + mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast(&step_lr_lambda), 100); - if ((i + 1) % 100 == 0) { - float acc = CalculateAccuracy(10); - if (max_acc < acc) max_acc = acc; - std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] " - << " max_acc=" << max_acc << std::endl; - } - } + mindspore::lite::LossMonitor lm(100); + // mindspore::lite::ClassificationTrainAccuracyMonitor am(10); + mindspore::lite::CkptSaver cs(1000, std::string("lenet")); + AccuracyMonitor test_am(&ds_, 500, 10); + DataLoader dl(&ds_); + + loop_->Train(cycles_, std::vector{&dl, &lm, &test_am, &cs, &step_lr_sched}); return 0; } @@ -194,8 +129,7 @@ int NetRunner::Main() { TrainLoop(); - float acc = CalculateAccuracy(); - std::cout << "accuracy = " << acc << std::endl; + CalculateAccuracy(); if (cycles_ > 0) { auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms"; diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.h b/mindspore/lite/examples/train_lenet/src/net_runner.h index 8687d15e0f..33f1aaa5b9 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.h +++ b/mindspore/lite/examples/train_lenet/src/net_runner.h @@ -23,6 +23,7 @@ #include #include #include "include/train_session.h" +#include "include/train/train_loop.h" #include "include/ms_tensor.h" #include "src/dataset.h" @@ -38,12 +39,13 @@ class NetRunner { int InitDB(); int TrainLoop(); std::vector FillInputData(const std::vector &dataset, bool is_train_set = false) const; - float CalculateAccuracy(int max_tests = -1) 
const; + float CalculateAccuracy(int max_tests = -1); float GetLoss() const; mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const; DataSet ds_; mindspore::session::TrainSession *session_ = nullptr; + mindspore::session::TrainLoop *loop_ = nullptr; std::string ms_file_ = ""; std::string data_dir_ = ""; diff --git a/mindspore/lite/examples/transfer_learning/src/net_runner.cc b/mindspore/lite/examples/transfer_learning/src/net_runner.cc index d8128dcccb..74dd2f9cdd 100644 --- a/mindspore/lite/examples/transfer_learning/src/net_runner.cc +++ b/mindspore/lite/examples/transfer_learning/src/net_runner.cc @@ -17,9 +17,10 @@ #include "src/net_runner.h" #include #include +#include #include -#include #include +#include #include "include/context.h" #include "src/utils.h" @@ -113,7 +114,7 @@ std::vector NetRunner::FillInputData(const std::vector &dat int label = 0; char *data = nullptr; std::tie(data, label) = dataset[idx]; - std::memcpy(input_data + i * data_size_, data, data_size_); + std::copy(data, data + data_size, input_data + i * data_size); labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation labels_vec.push_back(label); } diff --git a/mindspore/lite/include/train/ckpt_saver.h b/mindspore/lite/include/train/ckpt_saver.h new file mode 100644 index 0000000000..5582167b20 --- /dev/null +++ b/mindspore/lite/include/train/ckpt_saver.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ +#include +#include +#include +#include +#include +#include "include/train/train_loop.h" + +using GraphPoint = std::pair; + +namespace mindspore { +namespace lite { + +class CkptSaver : public session::TrainLoopCallBack { + public: + CkptSaver(int save_every_n, const std::string &filename_prefix) + : save_every_n_(save_every_n), filename_prefix_(filename_prefix) {} + + int EpochEnd(const session::TrainLoopCallBackData &cb_data) override { + if ((cb_data.epoch_ + 1) % save_every_n_ == 0) { + auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms"; + remove(cpkt_fn.c_str()); + cb_data.session_->SaveToFile(cpkt_fn); + } + return session::RET_CONTINUE; + } + + private: + int save_every_n_; + std::string filename_prefix_; +}; + +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_ diff --git a/mindspore/lite/include/train/classification_train_accuracy_monitor.h b/mindspore/lite/include/train/classification_train_accuracy_monitor.h new file mode 100644 index 0000000000..df44b0ceb3 --- /dev/null +++ b/mindspore/lite/include/train/classification_train_accuracy_monitor.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ +#include +#include +#include +#include +#include +#include "include/train/train_loop.h" + +using GraphPoint = std::pair; + +namespace mindspore { +namespace lite { + +class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack { + public: + explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {} + + virtual ~ClassificationTrainAccuracyMonitor() = default; + void Begin(const session::TrainLoopCallBackData &cb_data) override; + void EpochBegin(const session::TrainLoopCallBackData &cb_data) override; + int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; + void StepEnd(const session::TrainLoopCallBackData &cb_data) override; + const std::vector &GetAccuracyPoints() const { return accuracies_; } + + private: + std::vector accuracies_; + int print_every_n_ = 0; +}; + +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_ diff --git a/mindspore/lite/include/train/loss_monitor.h b/mindspore/lite/include/train/loss_monitor.h new file mode 100644 index 0000000000..824793761c --- /dev/null +++ b/mindspore/lite/include/train/loss_monitor.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ +#include +#include +#include +#include +#include +#include "include/train/train_loop_callback.h" + +using GraphPoint = std::pair; + +namespace mindspore { +namespace lite { + +class LossMonitor : public session::TrainLoopCallBack { + public: + explicit LossMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {} + virtual ~LossMonitor() = default; + void Begin(const session::TrainLoopCallBackData &cb_data) override; + void EpochBegin(const session::TrainLoopCallBackData &cb_data) override; + int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; + void StepEnd(const session::TrainLoopCallBackData &cb_data) override; + const std::vector &GetLossPoints() const { return losses_; } + + private: + std::vector losses_; + int print_every_n_; +}; + +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_ diff --git a/mindspore/lite/include/train/lr_scheduler.h b/mindspore/lite/include/train/lr_scheduler.h new file mode 100644 index 0000000000..25a392145d --- /dev/null +++ b/mindspore/lite/include/train/lr_scheduler.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ +#include +#include +#include +#include +#include +#include "include/train/train_loop_callback.h" + +namespace mindspore { +namespace lite { + +constexpr int DONT_UPDATE_LR = 0; +constexpr int UPDATE_LR = 1; + +using LR_Lambda = std::function; + +/// \brief Multiply the LR by a factor of gamma every epoch +int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication); + +/// \brief Multiply the LR by a factor of gamma every step_size +int StepLRLambda(float *lr, int epoch, void *step_size); +struct StepLRLambda { + StepLRLambda(int step, float g) : step_size(step), gamma(g) {} + + int step_size; // period of LR decay + float gamma; // LR decay factor +}; + +class LRScheduler : public session::TrainLoopCallBack { + public: + explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1); + virtual ~LRScheduler() = default; + int EpochEnd(const session::TrainLoopCallBackData &cb_data) override; + + private: + LR_Lambda lambda_func_; + void *lr_data_ = nullptr; + int step_ = 1; +}; + +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_ diff --git a/mindspore/lite/include/train/train_loop.h b/mindspore/lite/include/train/train_loop.h new file mode 100644 index 0000000000..c5f59f1fa9 --- /dev/null +++ b/mindspore/lite/include/train/train_loop.h @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ +#include +#include +#include +#include +#include "include/train/train_loop_callback.h" +#include "include/train_session.h" + +namespace mindspore { +namespace session { + +class TrainLoop { + public: + /// \brief Static method to create a TrainLoop object + /// + /// \param[in] filename Filename to read flatbuffer from + /// \param[in] context Defines the context of the session to be created + /// + /// \return Pointer of MindSpore Lite TrainLoop + static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1); + + /// \brief Class destructor + virtual ~TrainLoop() = default; + + /// \brief Resets the epoch counter + /// + /// \return 0 on success or -1 in case of error + virtual int Reset() = 0; // resets the epoch counter to 0. + + /// \brief Accessor to the TrainSession + /// + /// \return pointer of the train_session + virtual session::TrainSession *train_session() = 0; + + /// \brief Accessor to the Session KernelCallbacks + /// + /// \param[in] before Define a call_back_function to be called before running each node. + /// \param[in] after Define a call_back_function called after running each node. 
+ /// + /// \return 0 on success or -1 in case of error + virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0; + + /// \brief Performs the training Loop + /// + /// \param[in] epoch The number of epochs to run + /// \param[in] cbs A vector of TrainLoopCallBack objects + /// + /// \return 0 on success or -1 in case of error + virtual int Train(int epochs, std::vector cbs) = 0; +}; +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_ diff --git a/mindspore/lite/include/train/train_loop_callback.h b/mindspore/lite/include/train/train_loop_callback.h new file mode 100644 index 0000000000..4c17ac1d40 --- /dev/null +++ b/mindspore/lite/include/train/train_loop_callback.h @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ +#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ +#include +#include +#include +#include + +namespace mindspore { +namespace session { + +class TrainSession; +class TrainLoop; + +struct TrainLoopCallBackData { + TrainLoopCallBackData(bool train_mode, int epoch, TrainSession *session, TrainLoop *loop) + : train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {} + + bool train_mode_; /**< training mode of TrainSession object */ + unsigned int epoch_; /**< the current training epoch (starts at 0) */ + unsigned int step_ = 0; /**< the current step within the epoch */ + TrainSession *session_; /**< pointer to the TrainSession */ + TrainLoop *loop_; +}; + +constexpr int RET_CONTINUE = 0; +constexpr int RET_STOP_TRAINING = 1; +constexpr int RET_EXIT = 2; + +class TrainLoopCallBack { + public: + virtual ~TrainLoopCallBack() = default; + virtual void Begin(const TrainLoopCallBackData &cb_data) {} + virtual void End(const TrainLoopCallBackData &cb_data) {} + virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {} + virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; } + virtual void StepBegin(const TrainLoopCallBackData &cb_data) {} + virtual void StepEnd(const TrainLoopCallBackData &cb_data) {} +}; + +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_ diff --git a/mindspore/lite/include/train_session.h b/mindspore/lite/include/train_session.h index 0a7faf440a..aa39950393 100644 --- a/mindspore/lite/include/train_session.h +++ b/mindspore/lite/include/train_session.h @@ -83,6 +83,23 @@ class TrainSession : public session::LiteSession { /// \return boolean indication if model is in eval mode bool IsEval() { return train_mode_ == false; } + /// \brief Sets the Learning Rate of the training + /// + /// \param[in] learning_rate to set + /// + /// \return STATUS as an error code of the 
set operation, STATUS is defined in errorcode.h + virtual int SetLearningRate(float learning_rate) = 0; + + /// \brief Gets the Learning Rate of the training + /// + /// \return learning rate. 0.0 if no optimizer was found + virtual float GetLearningRate() = 0; + + /// \brief Get output MindSpore Lite MSTensors of Training model prediction + /// + /// \return The map of output tensor name and MindSpore Lite MSTensor. + virtual std::unordered_map GetPredictions() const = 0; + protected: bool train_mode_ = false; }; diff --git a/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java new file mode 100644 index 0000000000..536d519b7e --- /dev/null +++ b/mindspore/lite/java/java/app/src/main/java/com/mindspore/lite/TrainSession.java @@ -0,0 +1,178 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.mindspore.lite; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import com.mindspore.lite.config.MSConfig; + +public class TrainSession { + static { + System.loadLibrary("mindspore-lite-jni"); + } + + private long sessionPtr; + + public TrainSession() { + this.sessionPtr = 0; + } + + public boolean init(String modelFilename, MSConfig config) { + this.sessionPtr = createSession(modelFilename, config.getMSConfigPtr()); + return this.sessionPtr != 0; + } + + public long getSessionPtr() { + return sessionPtr; + } + + public void bindThread(boolean if_bind) { + this.bindThread(this.sessionPtr, if_bind); + } + + public boolean runGraph() { + return this.runGraph(this.sessionPtr); + } + + public List getInputs() { + List ret = this.getInputs(this.sessionPtr); + ArrayList tensors = new ArrayList(); + for (Long ms_tensor_addr : ret) { + MSTensor msTensor = new MSTensor(ms_tensor_addr); + tensors.add(msTensor); + } + return tensors; + } + + public MSTensor getInputsByTensorName(String tensorName) { + Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName); + if(tensor_addr == null){ + return null; + } + MSTensor msTensor = new MSTensor(tensor_addr); + return msTensor; + } + + public List getOutputsByNodeName(String nodeName) { + List ret = this.getOutputsByNodeName(this.sessionPtr, nodeName); + ArrayList tensors = new ArrayList<>(); + for (Long msTensorAddr : ret) { + MSTensor msTensor = new MSTensor(msTensorAddr); + tensors.add(msTensor); + } + return tensors; + } + + public Map getOutputMapByTensor() { + Map ret = this.getOutputMapByTensor(this.sessionPtr); + Map tensorMap = new HashMap<>(); + Set> entrySet = ret.entrySet(); + for (Map.Entry entry : entrySet) { + String name = entry.getKey(); + Long msTensorAddr = entry.getValue(); + tensorMap.put(name, new MSTensor(msTensorAddr)); + } + return tensorMap; + } + + public List getOutputTensorNames() { + 
return getOutputTensorNames(this.sessionPtr); + } + + public MSTensor getOutputByTensorName(String tensorName) { + Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName); + if(tensor_addr == null){ + return null; + } + return new MSTensor(tensor_addr); + } + + public void free() { + this.free(this.sessionPtr); + this.sessionPtr = 0; + } + + public boolean resize(List inputs, int[][] dims) { + long[] inputs_array = new long[inputs.size()]; + for (int i = 0; i < inputs.size(); i++) { + inputs_array[i] = inputs.get(i).getMSTensorPtr(); + } + return this.resize(this.sessionPtr, inputs_array, dims); + } + + public boolean saveToFile(String modelFilename) { + return this.saveToFile(this.sessionPtr, modelFilename); + } + + public boolean train() { + return this.train(this.sessionPtr); + } + + public boolean eval() { + return this.eval(this.sessionPtr); + } + + public boolean isTrain() { + return this.isTrain(this.sessionPtr); + } + + public boolean isEval() { + return this.isEval(this.sessionPtr); + } + + public boolean setLearningRate(float learning_rate) { + return this.setLearningRate(this.sessionPtr, learning_rate); + } + + private native long createSession(String modelFilename, long msConfigPtr); + + private native void bindThread(long sessionPtr, boolean if_bind); + + private native boolean runGraph(long sessionPtr); + + private native List getInputs(long sessionPtr); + + private native long getInputsByTensorName(long sessionPtr, String tensorName); + + private native List getOutputsByNodeName(long sessionPtr, String nodeName); + + private native Map getOutputMapByTensor(long sessionPtr); + + private native List getOutputTensorNames(long sessionPtr); + + private native long getOutputByTensorName(long sessionPtr, String tensorName); + + private native void free(long sessionPtr); + + private native boolean resize(long sessionPtr, long[] inputs, int[][] dims); + + private native boolean saveToFile(long sessionPtr, String modelFilename); + + private 
native boolean train(long sessionPtr); + + private native boolean eval(long sessionPtr); + + private native boolean isTrain(long sessionPtr); + + private native boolean isEval(long sessionPtr); + + private native boolean setLearningRate(long sessionPtr, float learning_rate); +} diff --git a/mindspore/lite/java/java/app/src/main/native/CMakeLists.txt b/mindspore/lite/java/java/app/src/main/native/CMakeLists.txt index 1afbbe8274..d51c1d34b2 100644 --- a/mindspore/lite/java/java/app/src/main/native/CMakeLists.txt +++ b/mindspore/lite/java/java/app/src/main/native/CMakeLists.txt @@ -7,8 +7,10 @@ set(PLATFORM_ARM "on") set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR}) set(MS_VERSION_MINOR ${MS_VERSION_MINOR}) set(MS_VERSION_REVISION ${MS_VERSION_REVISION}) -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \ + -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \ + -DMS_VERSION_REVISION=${MS_VERSION_REVISION}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") #set for cross-compiling toolchain @@ -16,16 +18,16 @@ set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) -if (ENABLE_VERBOSE) +if(ENABLE_VERBOSE) set(CMAKE_VERBOSE_MAKEFILE on) -endif () +endif() -if (PLATFORM_ARM32) +if(PLATFORM_ARM32) add_compile_definitions(ENABLE_ARM32) -endif () -if (PLATFORM_ARM64) +endif() +if(PLATFORM_ARM64) add_compile_definitions(ENABLE_ARM64) -endif () +endif() set(TOP_DIR 
${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../..) set(LITE_DIR ${TOP_DIR}/mindspore/lite) @@ -40,15 +42,25 @@ include_directories(${LITE_DIR}/build) ## flatbuffers link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}/) -add_library(mindspore-lite-jni SHARED - ${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp - ) +set(JNI_SRC + ${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp + ) + +if(SUPPORT_TRAIN) + set(JNI_SRC + ${JNI_SRC} + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp + ) +endif() + +add_library(mindspore-lite-jni SHARED ${JNI_SRC}) + find_library(log-lib log) -target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib}) \ No newline at end of file +target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib}) diff --git a/mindspore/lite/java/java/app/src/main/native/runtime/train_session.cpp b/mindspore/lite/java/java/app/src/main/native/runtime/train_session.cpp new file mode 100644 index 0000000000..c37d734c7e --- /dev/null +++ b/mindspore/lite/java/java/app/src/main/native/runtime/train_session.cpp @@ -0,0 +1,305 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "common/ms_log.h" +#include "common/jni_utils.h" +#include "include/train_session.h" +#include "include/errorcode.h" + +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_createSession(JNIEnv *env, jobject thiz, + jstring model_file_name, + jlong ms_config_ptr) { + auto *pointer = reinterpret_cast(ms_config_ptr); + if (pointer == nullptr) { + MS_LOGE("Context pointer from java is nullptr"); + return jlong(nullptr); + } + auto *lite_context_ptr = static_cast(pointer); + auto session = mindspore::session::TrainSession::CreateSession(JstringToChar(env, model_file_name), lite_context_ptr); + if (session == nullptr) { + MS_LOGE("CreateSession failed"); + return jlong(nullptr); + } + return jlong(session); +} + +extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_bindThread(JNIEnv *env, jobject thiz, + jlong session_ptr, jboolean if_bind) { + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return; + } + auto *train_session_ptr = static_cast(pointer); + train_session_ptr->BindThread(if_bind); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_runGraph(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(pointer); + auto ret = train_session_ptr->RunGraph(); + return (jboolean)(ret == 
mindspore::lite::RET_OK); +} + +extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getInputs(JNIEnv *env, jobject thiz, + jlong session_ptr) { + jclass array_list = env->FindClass("java/util/ArrayList"); + jmethodID array_list_construct = env->GetMethodID(array_list, "", "()V"); + jobject ret = env->NewObject(array_list, array_list_construct); + jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); + + jclass long_object = env->FindClass("java/lang/Long"); + jmethodID long_object_construct = env->GetMethodID(long_object, "", "(J)V"); + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return ret; + } + auto *train_session_ptr = static_cast(pointer); + auto inputs = train_session_ptr->GetInputs(); + for (auto input : inputs) { + jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); + env->CallBooleanMethod(ret, array_list_add, tensor_addr); + } + return ret; +} + +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getInputsByTensorName(JNIEnv *env, jobject thiz, + jlong session_ptr, + jstring tensor_name) { + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return jlong(nullptr); + } + auto *train_session_ptr = static_cast(pointer); + auto input = train_session_ptr->GetInputsByTensorName(JstringToChar(env, tensor_name)); + return jlong(input); +} + +extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputsByNodeName(JNIEnv *env, + jobject thiz, + jlong session_ptr, + jstring node_name) { + jclass array_list = env->FindClass("java/util/ArrayList"); + jmethodID array_list_construct = env->GetMethodID(array_list, "", "()V"); + jobject ret = env->NewObject(array_list, array_list_construct); + jmethodID array_list_add = env->GetMethodID(array_list, "add", 
"(Ljava/lang/Object;)Z"); + + jclass long_object = env->FindClass("java/lang/Long"); + jmethodID long_object_construct = env->GetMethodID(long_object, "", "(J)V"); + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return ret; + } + auto *train_session_ptr = static_cast(pointer); + auto inputs = train_session_ptr->GetOutputsByNodeName(JstringToChar(env, node_name)); + for (auto input : inputs) { + jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(input)); + env->CallBooleanMethod(ret, array_list_add, tensor_addr); + } + return ret; +} + +extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputMapByTensor(JNIEnv *env, + jobject thiz, + jlong session_ptr) { + jclass hash_map_clazz = env->FindClass("java/util/HashMap"); + jmethodID hash_map_construct = env->GetMethodID(hash_map_clazz, "", "()V"); + jobject hash_map = env->NewObject(hash_map_clazz, hash_map_construct); + jmethodID hash_map_put = + env->GetMethodID(hash_map_clazz, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"); + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return hash_map; + } + auto *train_session_ptr = static_cast(pointer); + auto outputs = train_session_ptr->GetOutputs(); + jclass long_object = env->FindClass("java/lang/Long"); + jmethodID long_object_construct = env->GetMethodID(long_object, "", "(J)V"); + for (auto output_iter : outputs) { + auto node_name = output_iter.first; + auto ms_tensor = output_iter.second; + jobject tensor_addr = env->NewObject(long_object, long_object_construct, jlong(ms_tensor)); + env->CallObjectMethod(hash_map, hash_map_put, env->NewStringUTF(node_name.c_str()), tensor_addr); + } + return hash_map; +} + +extern "C" JNIEXPORT jobject JNICALL Java_com_mindspore_lite_TrainSession_getOutputTensorNames(JNIEnv *env, + jobject thiz, + jlong 
session_ptr) { + jclass array_list = env->FindClass("java/util/ArrayList"); + jmethodID array_list_construct = env->GetMethodID(array_list, "", "()V"); + jobject ret = env->NewObject(array_list, array_list_construct); + jmethodID array_list_add = env->GetMethodID(array_list, "add", "(Ljava/lang/Object;)Z"); + + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return ret; + } + auto *train_session_ptr = static_cast(pointer); + auto output_names = train_session_ptr->GetOutputTensorNames(); + for (auto output_name : output_names) { + env->CallBooleanMethod(ret, array_list_add, env->NewStringUTF(output_name.c_str())); + } + return ret; +} + +extern "C" JNIEXPORT jlong JNICALL Java_com_mindspore_lite_TrainSession_getOutputByTensorName(JNIEnv *env, jobject thiz, + jlong session_ptr, + jstring tensor_name) { + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return jlong(nullptr); + } + auto *train_session_ptr = static_cast(pointer); + auto output = train_session_ptr->GetOutputByTensorName(JstringToChar(env, tensor_name)); + return jlong(output); +} + +extern "C" JNIEXPORT void JNICALL Java_com_mindspore_lite_TrainSession_free(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return; + } + auto *train_session_ptr = static_cast(pointer); + delete (train_session_ptr); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_resize(JNIEnv *env, jobject thiz, + jlong session_ptr, jlongArray inputs, + jobjectArray dims) { + std::vector> c_dims; + auto *pointer = reinterpret_cast(session_ptr); + if (pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return false; + } + auto *train_session_ptr = static_cast(pointer); + + jsize input_size = 
static_cast(env->GetArrayLength(inputs)); + jlong *input_data = env->GetLongArrayElements(inputs, nullptr); + std::vector c_inputs; + for (int i = 0; i < input_size; i++) { + auto *tensor_pointer = reinterpret_cast(input_data[i]); + if (tensor_pointer == nullptr) { + MS_LOGE("Tensor pointer from java is nullptr"); + return false; + } + auto *ms_tensor_ptr = static_cast(tensor_pointer); + c_inputs.push_back(ms_tensor_ptr); + } + jsize tensor_size = static_cast(env->GetArrayLength(dims)); + for (int i = 0; i < tensor_size; i++) { + jintArray array = static_cast(env->GetObjectArrayElement(dims, i)); + jsize dim_size = static_cast(env->GetArrayLength(array)); + jint *dim_data = env->GetIntArrayElements(array, nullptr); + std::vector tensor_dims; + for (int j = 0; j < dim_size; j++) { + tensor_dims.push_back(dim_data[j]); + } + c_dims.push_back(tensor_dims); + } + int ret = train_session_ptr->Resize(c_inputs, c_dims); + return (jboolean)(ret == mindspore::lite::RET_OK); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_saveToFile(JNIEnv *env, jobject thiz, + jlong session_ptr, + jstring model_file_name) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->SaveToFile(JstringToChar(env, model_file_name)); + return (jboolean)(ret == 0); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_train(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->Train(); + return (jboolean)(ret == mindspore::lite::RET_OK); +} + +extern "C" JNIEXPORT jboolean 
JNICALL Java_com_mindspore_lite_TrainSession_eval(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->Eval(); + return (jboolean)(ret == mindspore::lite::RET_OK); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isTrain(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->IsTrain(); + return (jboolean)(ret); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_isEval(JNIEnv *env, jobject thiz, + jlong session_ptr) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->IsEval(); + return (jboolean)(ret); +} + +extern "C" JNIEXPORT jboolean JNICALL Java_com_mindspore_lite_TrainSession_setLearningRate(JNIEnv *env, jobject thiz, + jlong session_ptr, + jfloat learning_rate) { + auto *session_pointer = reinterpret_cast(session_ptr); + if (session_pointer == nullptr) { + MS_LOGE("Session pointer from java is nullptr"); + return (jboolean) false; + } + auto *train_session_ptr = static_cast(session_pointer); + auto ret = train_session_ptr->SetLearningRate(learning_rate); + return (jboolean)(ret == mindspore::lite::RET_OK); +} diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c index 6bcdc9dcdd..adf2ab37af 100644 --- 
a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c @@ -800,8 +800,8 @@ int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) { int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size, ArithmeticParameter *param) { - TileDimensionsFp32(in0, in1, tile_in0, tile_in0, param); - return ElementDiv(tile_in0, tile_in0, out, size); + TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param); + return ElementDiv(tile_in0, tile_in1, out, size); } int ElementDiv(const float *in0, const float *in1, float *out, int size) { diff --git a/mindspore/lite/nnacl/fp32_grad/activation_grad.c b/mindspore/lite/nnacl/fp32_grad/activation_grad.c index 55b17f7c27..cb713a550b 100644 --- a/mindspore/lite/nnacl/fp32_grad/activation_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/activation_grad.c @@ -21,32 +21,48 @@ #include "nnacl/errorcode.h" inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) { - for (size_t i = 0; i < length; ++i) { - if (src1[i] > 0) { - dst[i] = src0[i]; - } else { - dst[i] = 0; - } + int i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + for (; i < length - 4; i += 4) { + float32x4_t src1_4 = vld1q_f32(src1 + i); + float32x4_t src0_4 = vld1q_f32(src0 + i); + uint32x4_t mask_4 = vcgtq_f32(src1_4, zero_4); + float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f) ? 
src0[i] : 0.0f; } return NNACL_OK; } int Relu6Grad(float *src0, float *src1, size_t length, float *dst) { - for (size_t i = 0; i < length; ++i) { - if (src1[i] > 0.0f && src1[i] <= 6.0f) { - dst[i] = src0[i]; - } else { - dst[i] = 0.0f; - } + int i = 0; +#ifdef ENABLE_ARM + float32x4_t zero_4 = vdupq_n_f32(0.0f); + float32x4_t six_4 = vdupq_n_f32(6.0f); + for (; i < length - 4; i += 4) { + float32x4_t src1_4 = vld1q_f32(src1 + i); + float32x4_t src0_4 = vld1q_f32(src0 + i); + float32x4_t max_4 = vmaxq_f32(src1_4, zero_4); + float32x4_t min_max_4 = vminq_f32(max_4, six_4); + uint32x4_t mask_4 = vceqq_f32(min_max_4, src1_4); + float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4); + vst1q_f32(dst + i, dst_4); + } +#endif + for (; i < length; ++i) { + dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f; } return NNACL_OK; } int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) { for (size_t i = 0; i < length; ++i) { - dst[i] = src1[i] > 0.0f ? 1.0f : alpha; + dst[i] = src1[i] > 0.0f ? 
src0[i] : alpha * src0[i]; } - ElementMul(src0, dst, dst, length); return NNACL_OK; } diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.c b/mindspore/lite/nnacl/fp32_grad/batch_norm.c index d2f489dc52..add506cfb7 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.c +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.c @@ -17,55 +17,36 @@ #include #include "nnacl/fp32_grad/batch_norm.h" -void sumSpatialBatch(const float *in, size_t size, int ch, float *out) { - memset(out, 0, ch * sizeof(float)); - for (size_t i = 0; i < size; i++) { - const float *ptr = in + (i * ch); - for (size_t c = 0; c < ch; c++) { - out[c] += ptr[c]; - } +void var2Invar(float *save_var, int size, float eps) { + for (int i = 0; i < size; i++) { + save_var[i] = 1.0f / sqrt(save_var[i] + eps); } } -void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean, - float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) { - const float N = (size); - for (size_t i = 0; i < size; i++) { - for (size_t f = 0; f < channels; f++) { - size_t ix = i * channels + f; - float x_hat = (in[ix] - mean[f]) * invar[f]; - float dx_hat = dout[ix] * scale[f]; - dxhat_sum[f] += dx_hat; - dxhathat_sum[f] += dx_hat * x_hat; +void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean, + const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum, + float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale, float *restrict dx) { + float N = (float)size; + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + int ix = i * ch + c; + dbias[c] += yt[ix]; + // dscale + float x_hat = (in[ix] - mean[c]) * invar[c]; + dscale[c] += (yt[ix] * x_hat); + // dx_1 + float dx_hat = yt[ix] * scale[c]; + dxhat_sum[c] += dx_hat; + dxhathat_sum[c] += dx_hat * x_hat; } } - for (size_t i = 0; i < size; i++) { - for (size_t f = 0; f < channels; f++) { - size_t ix = i 
* channels + f; - float x_hat = (in[ix] - mean[f]) * invar[f]; - float dx_hat = dout[ix] * scale[f]; - out[ix] = 1.0f / N * (invar[f]) * (N * dx_hat - dxhat_sum[f] - x_hat * dxhathat_sum[f]); - } - } -} - -void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n, - int size, float *scale_updates) { - size_t i, b, f; - memset(scale_updates, 0, n * sizeof(float)); - for (b = 0; b < batch; ++b) { - for (i = 0; i < size; ++i) { - for (f = 0; f < n; ++f) { - int index = (b * size + i) * n + f; - float x_norm = (x[index] - mean[f]) * invar[f]; - scale_updates[f] += (delta[index] * x_norm); - } + for (int i = 0; i < size; i++) { + for (int c = 0; c < ch; c++) { + // dx_2 + int ix = i * ch + c; + float x_hat = (in[ix] - mean[c]) * invar[c]; + float dx_hat = yt[ix] * scale[c]; + dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]); } } } - -void var2Invar(float *save_var, size_t size, float eps) { - for (size_t i = 0; i < size; i++) { - save_var[i] = 1.0f / sqrt(save_var[i] + eps); - } -} diff --git a/mindspore/lite/nnacl/fp32_grad/batch_norm.h b/mindspore/lite/nnacl/fp32_grad/batch_norm.h index 53cc6437da..b3728d6d75 100644 --- a/mindspore/lite/nnacl/fp32_grad/batch_norm.h +++ b/mindspore/lite/nnacl/fp32_grad/batch_norm.h @@ -29,13 +29,9 @@ typedef struct BNGradParameter { extern "C" { #endif -void sumSpatialBatch(const float *in, size_t size, int ch, float *out); -void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean, - float *invar, float *xhat_sum, float *dxhat_sum, float *out); -void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n, - int size, float *scale_updates); -void var2Invar(float *save_var, size_t size, float eps); - +void var2Invar(float *save_var, int size, float eps); +void backwardAll(const float *in, const float *yt, const float *mean, const float 
*invar, const float *scale, int size, + int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c b/mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c index 6965fc051c..95e28c8c65 100644 --- a/mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c +++ b/mindspore/lite/nnacl/fp32_grad/binary_cross_entropy_grad.c @@ -20,7 +20,7 @@ int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y, const float *weight, const float *dloss, float *dx) { - const float epsilon = 1e-12; + const float epsilon = 1e-12f; if (reduction == 0) { for (int i = 0; i < input_size; i++) { float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon); diff --git a/mindspore/lite/nnacl/fp32_grad/gemm.c b/mindspore/lite/nnacl/fp32_grad/gemm.c index 8b3c35f677..8815498f93 100644 --- a/mindspore/lite/nnacl/fp32_grad/gemm.c +++ b/mindspore/lite/nnacl/fp32_grad/gemm.c @@ -21,7 +21,7 @@ #endif #include "nnacl/fp32/matmul_fp32.h" -static void addv(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) { +void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) { const float *src_ptr = v1; float *dst_ptr = v2; for (int r = 0; r < row; r++) { @@ -86,7 +86,8 @@ static void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int return; } -static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) { +static void RowMajor2Col12MajorStride(const float *restrict src_ptr, float *restrict dst_ptr, size_t row, size_t col, + int lead) { size_t row_up_12 = UP_ROUND(row, C12NUM); size_t row12 = row / C12NUM * C12NUM; size_t col4 = col / C4NUM * C4NUM; @@ -549,7 +550,7 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa #else MatMulOpt(mat_a_input, 
mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc); #endif - if (incremental) addv(output, mat_c, beta, M, N, ldc); + if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc); gcb->mat_a = mat_a_input; gcb->mat_b = mat_b_input; } diff --git a/mindspore/lite/nnacl/fp32_grad/gemm.h b/mindspore/lite/nnacl/fp32_grad/gemm.h index b1da9b8288..91ae794584 100644 --- a/mindspore/lite/nnacl/fp32_grad/gemm.h +++ b/mindspore/lite/nnacl/fp32_grad/gemm.h @@ -37,6 +37,7 @@ void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *m int ldb, float beta, float *mat_c, int ldc, float *workspace); int MatSize(int row, int col, int round); int MatSizeTotal(int row, int col, int deep, int inc); +void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32_grad/pack_ext.c b/mindspore/lite/nnacl/fp32_grad/pack_ext.c index 1406656ecb..660484aecc 100644 --- a/mindspore/lite/nnacl/fp32_grad/pack_ext.c +++ b/mindspore/lite/nnacl/fp32_grad/pack_ext.c @@ -18,9 +18,8 @@ #include "nnacl/fp32_grad/pack_ext.h" #include "nnacl/pack.h" -static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); } - -void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start) { +void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int real_cal_num, + int start) { const int pad_left = conv_param->pad_l_; const int pad_up = conv_param->pad_u_; @@ -43,22 +42,43 @@ void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParamet int kernel_row, kernel_col; - for (int i = 0; i < rows; i++) { - int block_start = start + i; - int input_h = block_start / output_w * stride_h; - int input_w = block_start % output_w * stride_w; - for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { - int input_row = -pad_up + kernel_row * dilation_h + input_h; 
- for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { - int input_col = -pad_left + kernel_col * dilation_w + input_w; - - if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { - const int offset = (input_row * in_width + input_col) * tot_channels; - memcpy(data_col, in_data + offset, sizeof(float) * channels); - data_col += channels; - } else { - memset(data_col, 0, sizeof(float) * channels); - data_col += channels; + if (channels == 1) { + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + *data_col = in_data[offset]; + data_col++; + } else { + *data_col = 0; + data_col++; + } + } + } + } + } else { + for (int i = 0; i < real_cal_num; i++) { + int block_start = start + i; + int input_h = block_start / output_w * stride_h; + int input_w = block_start % output_w * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + input_h; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + input_w; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + const int offset = (input_row * in_width + input_col) * tot_channels; + memcpy(data_col, in_data + offset, sizeof(float) * channels); + data_col += channels; + } else { + memset(data_col, 0, sizeof(float) * channels); + data_col += channels; 
+ } } } } @@ -70,7 +90,6 @@ void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *con rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index); } -// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w) void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) { const int pad_left = conv_param->pad_l_; const int pad_up = conv_param->pad_u_; @@ -100,14 +119,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { int input_row = -pad_up + kernel_row * dilation_h; for (output_rows = output_h; output_rows; output_rows--) { - if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { + if (!((unsigned)(input_row) < (unsigned)(in_height))) { for (output_col = output_w; output_col; output_col--) { *(data_row++) = 0; } } else { int input_col = -pad_left + kernel_col * dilation_w; for (output_col = output_w; output_col; output_col--) { - if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { const int offset = (input_row * in_width + input_col) * tot_channels + channel; *(data_row++) = in_data[offset]; } else { @@ -127,14 +146,14 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv for (channel = 0; channel < channels; channel++) { int input_row = -pad_up + kernel_row * dilation_h; for (output_rows = output_h; output_rows; output_rows--) { - if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { + if (!((unsigned)(input_row) < (unsigned)(in_height))) { for (output_col = output_w; output_col; output_col--) { *(data_row++) = 0; } } else { int input_col = -pad_left + kernel_col * dilation_w; for (output_col = output_w; output_col; output_col--) { - if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { const int offset = (input_row * in_width + 
input_col) * tot_channels + channel; *(data_row++) = in_data[offset]; } else { @@ -150,7 +169,6 @@ void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv } } } - void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) { const int pad_left = conv_param->pad_l_; const int pad_up = conv_param->pad_u_; @@ -177,14 +195,14 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { for (output_rows = start; output_rows < start + rows; output_rows++) { int input_row = -pad_up + kernel_row * dilation_h + output_rows * stride_h; - if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) { + if (!((unsigned)(input_row) < (unsigned)(in_height))) { for (output_col = output_w; output_col; output_col--) { *(data_row++) = 0; } } else { int input_col = -pad_left + kernel_col * dilation_w; for (output_col = output_w; output_col; output_col--) { - if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + if (((unsigned)(input_col) < (unsigned)(in_width))) { const int offset = (input_row * in_width + input_col) * tot_channels + channel; *(data_row++) = in_data[offset]; } else { @@ -193,7 +211,6 @@ void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParamet input_col += stride_w; } } - // input_row += stride_h; } } } @@ -232,8 +249,7 @@ void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; - - if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { int offset = (input_row * in_width + input_col) * tot_channels; float *data_im_ptr 
= &data_im[offset]; for (int i = 0; i < channels; i++) { @@ -271,20 +287,36 @@ void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParamet int kernel_row, kernel_col; - for (int r = 0; r < rows; r++) { - int output_col = (start + r) % output_w; - int output_row = (start + r) / output_w; - int row_stride_offset = output_row * stride_h; - int col_stride_offset = output_col * stride_w; - - // for (output_col = 0; output_col < output_w; output_col++) - { + if (channels == 1) { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; - - if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) { + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { + int offset = (input_row * in_width + input_col) * tot_channels; + float *data_im_ptr = &data_im[offset]; + *data_im_ptr += *data_col; + } + data_col++; + } + } + } + } else { + for (int r = 0; r < rows; r++) { + int output_col = (start + r) % output_w; + int output_row = (start + r) / output_w; + int row_stride_offset = output_row * stride_h; + int col_stride_offset = output_col * stride_w; + for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) { + int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset; + for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) { + int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; + if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { int offset = (input_row * 
in_width + input_col) * tot_channels; float *data_im_ptr = &data_im[offset]; for (int i = 0; i < channels; i++) { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 732d3bdae2..036e329615 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -308,7 +308,7 @@ table SoftmaxCrossEntropy { } table SparseSoftmaxCrossEntropy { - isGrad: int; + isGrad: bool; } table make_tuple { @@ -1225,11 +1225,9 @@ table SmoothL1LossGrad { } table SigmoidCrossEntropyWithLogits { - beta : float; } table SigmoidCrossEntropyWithLogitsGrad { - beta : float; } table Reciprocal { diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index 71dd6978f9..69ef57da7a 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -65,7 +65,10 @@ if(SUPPORT_TRAIN) ${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc ${CMAKE_CURRENT_SOURCE_DIR}/train/train_model.cc - ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/train_loop.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/loss_monitor.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/lr_scheduler.cc + ${CMAKE_CURRENT_SOURCE_DIR}/train/classification_train_accuracy_monitor.cc ) endif() diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 4473934538..2f03666eb9 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -19,6 +19,9 @@ #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" #endif +#ifdef SUPPORT_TRAIN +#include +#endif namespace mindspore { namespace lite { @@ -53,12 +56,20 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector &inputs } string paddingmode = "REFLECT"; if (prim.GetAttr("mode") == nullptr) { - MS_LOG(ERROR) << "get mode failed!"; - delete this->primitive_; - delete attr; - this->primitive_ = nullptr; - attr = nullptr; - return RET_ERROR; +#ifdef SUPPORT_TRAIN + if 
(prim.name() == "Pad") { + paddingmode = "CONSTANT"; + } else { +#endif + MS_LOG(ERROR) << "get mode failed!"; + delete this->primitive_; + delete attr; + this->primitive_ = nullptr; + attr = nullptr; + return RET_ERROR; +#ifdef SUPPORT_TRAIN + } +#endif } else { paddingmode = GetValue(prim.GetAttr("mode")); } @@ -66,6 +77,21 @@ int Pad::UnPackAttr(const Primitive &prim, const std::vector &inputs attr->paddingMode = schema::PaddingMode_REFLECT; } else if (paddingmode == "SYMMETRIC") { attr->paddingMode = schema::PaddingMode_SYMMETRIC; +#ifdef SUPPORT_TRAIN + } else if (paddingmode == "CONSTANT") { + attr->paddingMode = schema::PaddingMode_CONSTANT; + if (prim.GetAttr("paddings") != nullptr) { + auto paddings = prim.GetAttr("paddings"); + auto str = (*paddings).ToString(); + std::replace(str.begin(), str.end(), ',', ' '); + std::replace(str.begin(), str.end(), ')', ' '); + std::replace(str.begin(), str.end(), '(', ' '); + std::stringstream ss(str); + for (int i; ss >> i;) { + attr->paddings.push_back(i); + } + } +#endif } else { MS_LOG(ERROR) << "model type not supported!"; delete this->primitive_; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 769b2eb7cc..e48124ffdc 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -674,7 +674,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: } else if ((op_type == "ReluGrad" || op_type == "ReLU6Grad" || op_type == "SigmoidGrad" || op_type == "HSigmoidGrad" || op_type == "HSwishGrad")) { return NewPrimitiveC(prim, inputs, quantType); - } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu")) { + } else if ((op_type == "MaxPoolGrad") || (op_type == "AvgPoolGrad") || (op_type == "AvgPoolGradGpu") || + (op_type == "AvgPoolGradCpu")) { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Conv2DBackpropFilter") { return NewPrimitiveC(prim, inputs, 
quantType); @@ -684,7 +685,7 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "FlattenGrad") { return NewPrimitiveC(prim, inputs, quantType); - } else if (op_type == "FusedBatchNormGrad") { + } else if ((op_type == "FusedBatchNormGrad") || (op_type == "FusedBatchNormGradCpu")) { return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "PowerGrad") { return NewPrimitiveC(prim, inputs, quantType); @@ -714,6 +715,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "SigmoidCrossEntropyWithLogitsGrad") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Pad") { + return NewPrimitiveC(prim, inputs, quantType); #else } else if (op_type == "Conv2DBackpropInput") { return NewPrimitiveC(prim, inputs, quantType); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc index 981b2802d4..b2a76cec3d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.cc @@ -237,6 +237,9 @@ int Convolution1x1CPUKernel::Run() { MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; return RET_MEMORY_FAILED; } + if (IsTrain()) { + PackWeight(); + } for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { output_ptr_ = src_out + batch_index * matmul_param_->row_ * matmul_param_->col_; @@ -261,4 +264,45 @@ int Convolution1x1CPUKernel::Run() { } return RET_OK; } + +void Convolution1x1CPUKernel::PackWeight() { + auto filter_tensor = in_tensors_.at(kWeightIndex); + auto input_channel = filter_tensor->Channel(); + auto output_channel = filter_tensor->Batch(); + +#ifdef ENABLE_AVX + row_tile_ = C6NUM; + col_tile_ = C16NUM; +#elif defined(ENABLE_SSE) + row_tile_ = C4NUM; + 
col_tile_ = C8NUM; +#elif defined(ENABLE_ARM32) + row_tile_ = C12NUM; + col_tile_ = C4NUM; +#else + row_tile_ = C12NUM; + col_tile_ = C8NUM; +#endif + + int size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float); + int down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float); + memset(reinterpret_cast(weight_ptr_) + down_size, 0, size - down_size); +#ifdef ENABLE_AVX + RowMajor2Col16Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, + input_channel); +#elif defined(ENABLE_ARM32) + RowMajor2Col4Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, + input_channel); +#else + RowMajor2Col8Major(reinterpret_cast(filter_tensor->MutableData()), weight_ptr_, output_channel, + input_channel); +#endif +} + +int Convolution1x1CPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h index e048214c96..85826b7a0f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_1x1_fp32.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ #include #include @@ -42,6 +42,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { int Init() override; int Run() override; int ReSize() override; + int Eval() override; public: int DoConv1x1(int task_id); @@ -53,6 +54,7 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { void InitConv1x1MatmulParam(); void FreeTmpBuffer(); void PackMatmulInput(const float *src_ptr, float *dst_ptr, int row, int col); + void PackWeight(); private: MatMulParameter *matmul_param_ = nullptr; @@ -70,4 +72,4 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel { int col_tile_ = 0; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_1X1_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h index 89a4d1c2c2..09109e9239 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.h @@ -47,6 +47,15 @@ class ConvolutionDelegateCPUKernel : public LiteKernel { static float *CopyData(lite::Tensor *tensor); void FreeCopiedData(); + int Eval() override { + LiteKernel::Eval(); + return conv_kernel_->Eval(); + } + int Train() override { + LiteKernel::Train(); + return conv_kernel_->Train(); + } + protected: bool need_free_weight_ = false; bool need_free_bias_ = false; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc 
index db95ea359f..2eb0714fd4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc @@ -127,6 +127,10 @@ int ConvolutionDepthwise3x3CPUKernel::Run() { return ret; } + if (IsTrain()) { + PackWeight(); + } + auto input_tensor = in_tensors_.at(kInputIndex); input_ptr_ = reinterpret_cast(input_tensor->data_c()); @@ -146,4 +150,18 @@ int ConvolutionDepthwise3x3CPUKernel::Run() { context_->allocator->Free(buffer_); return RET_OK; } + +void ConvolutionDepthwise3x3CPUKernel::PackWeight() { + auto weight_tensor = in_tensors_.at(kWeightIndex); + auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); + PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), + weight_tensor->Batch()); +} + +int ConvolutionDepthwise3x3CPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h index e37ab40002..ac9047d7d2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ #include #include "src/lite_kernel.h" @@ -37,8 +37,10 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { int InitWeightBias(); int Execute(int task_id); + int Eval() override; private: + void PackWeight(); int InitBuffer(); SlidingWindowParam *sliding_ = nullptr; float *packed_weight_ = nullptr; @@ -48,4 +50,4 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc index e7f597ac13..579fcc024d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.cc @@ -15,6 +15,7 @@ */ #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" +#include #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" #include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h" #include "schema/model_generated.h" @@ -104,6 +105,10 @@ int ConvDwRun(void *cdata, int task_id) { } int ConvolutionDepthwiseCPUKernel::Run() { + if (IsTrain()) { + PackWeight(); + } + auto input_tensor = in_tensors_.at(kInputIndex); input_ptr_ = reinterpret_cast(input_tensor->MutableData()); @@ -118,6 +123,19 @@ int ConvolutionDepthwiseCPUKernel::Run() { return RET_OK; } +void ConvolutionDepthwiseCPUKernel::PackWeight() { 
+ auto weight_tensor = in_tensors_.at(kWeightIndex); + auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); + PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), + weight_tensor->Batch()); +} + +int ConvolutionDepthwiseCPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const InnerContext *ctx, const kernel::KernelKey &desc, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h index 27a51372d9..e727e2e35e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ #include #include "src/lite_kernel.h" @@ -37,12 +37,14 @@ class ConvolutionDepthwiseCPUKernel : public ConvolutionBaseCPUKernel { int InitWeightBias(); int Execute(int task_id); + int Eval() override; private: + void PackWeight(); float *packed_weight_ = nullptr; float *input_ptr_ = nullptr; float *output_ptr_ = nullptr; }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.cc index 
f07f1422df..403c838962 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.cc @@ -190,6 +190,10 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() { packed_input_ = input_ptr; } + if (IsTrain()) { + PackWeight(); + } + auto output_tensor = out_tensors_.at(kOutputIndex); output_ptr_ = reinterpret_cast(output_tensor->data_c()); @@ -205,4 +209,23 @@ int ConvolutionDepthwiseIndirectCPUKernel::Run() { } return RET_OK; } + +void ConvolutionDepthwiseIndirectCPUKernel::PackWeight() { + auto weight_tensor = in_tensors_[kWeightIndex]; + auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); +#ifdef ENABLE_AVX + PackDepthwiseIndirectWeightC8Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(), + weight_tensor->Batch()); +#else + PackDepthwiseIndirectWeightC4Fp32(origin_weight, packed_weight_, weight_tensor->Height(), weight_tensor->Width(), + weight_tensor->Batch()); +#endif +} + +int ConvolutionDepthwiseIndirectCPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h index e1ffda12c9..49eae476c2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ #include #include "src/lite_kernel.h" @@ -37,10 +37,12 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel { int InitWeightBias(); int Execute(int task_id); + int Eval() override; private: int MallocIndirectBuffer(); int MallocPackedInput(); + void PackWeight(); int step_w = 0; int step_h = 0; float **indirect_buffer_ = nullptr; @@ -51,4 +53,4 @@ class ConvolutionDepthwiseIndirectCPUKernel : public ConvolutionBaseCPUKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_INDIRECT_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.cc index 45d0ea6750..7fa4b8b4c1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.cc @@ -145,6 +145,11 @@ int ConvolutionDepthwiseSWCPUKernel::Run() { FreePackedInputOutput(); return RET_ERROR; } + + if (IsTrain()) { + PackWeight(); + } + auto input_tensor = in_tensors_.at(kInputIndex); auto input_ptr = reinterpret_cast(input_tensor->MutableData()); @@ -183,4 +188,18 @@ void ConvolutionDepthwiseSWCPUKernel::FreePackedInputOutput() { packed_output_ = nullptr; } } + +void ConvolutionDepthwiseSWCPUKernel::PackWeight() { + auto weight_tensor = in_tensors_.at(kWeightIndex); + auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); + 
PackNCHWToNC4HW4Fp32(origin_weight, packed_weight_, 1, weight_tensor->Height() * weight_tensor->Width(), + weight_tensor->Batch()); +} + +int ConvolutionDepthwiseSWCPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h index 12c8cbc1dc..360223cdc4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ #include #include "src/lite_kernel.h" @@ -37,10 +37,12 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel { int InitWeightBias(); int Execute(int task_id); + int Eval() override; private: int InitPackedInputOutput(); void FreePackedInputOutput(); + void PackWeight(); SlidingWindowParam *sliding_ = nullptr; float *packed_weight_ = nullptr; float *packed_input_ = nullptr; @@ -49,4 +51,4 @@ class ConvolutionDepthwiseSWCPUKernel : public ConvolutionBaseCPUKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_SLIDEWINDOW_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc index bf01d96bab..06262b8a9a 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.cc @@ -150,6 +150,9 @@ int ConvolutionCPUKernel::Run() { FreeTmpBuffer(); return RET_ERROR; } + if (IsTrain()) { + PackWeight(); + } ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionImpl, this, thread_count_); if (ret != RET_OK) { @@ -158,4 +161,37 @@ int ConvolutionCPUKernel::Run() { FreeTmpBuffer(); return ret; } + +void ConvolutionCPUKernel::PackWeight() { + auto filter_tensor = in_tensors_.at(kWeightIndex); + int in_channel = filter_tensor->Channel(); + int out_channel = filter_tensor->Batch(); + int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); +#ifdef ENABLE_AVX + const int oc_block = C16NUM; +#elif ENABLE_ARM32 + const int oc_block = C4NUM; +#else + const int oc_block = C8NUM; +#endif + int oc_block_num = UP_ROUND(out_channel, oc_block); + int pack_weight_size = oc_block_num * in_channel * kernel_plane; + + auto origin_weight = reinterpret_cast(filter_tensor->data_c()); + memset(packed_weight_, 0, pack_weight_size * sizeof(float)); +#ifdef ENABLE_AVX + RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); +#elif ENABLE_ARM32 + RowMajor2Col4Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); +#else + RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane); +#endif +} + +int ConvolutionCPUKernel::Eval() { + LiteKernel::Eval(); + PackWeight(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h index b5a4762f90..e3eb5a5649 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_fp32.h @@ -46,7 +46,10 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel { int Run() override; 
virtual int RunImpl(int task_id); + int Eval() override; + protected: + void PackWeight(); void FreeTmpBuffer() { if (packed_input_ != nullptr) { ctx_->allocator->Free(packed_input_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index c5b87ef621..f6282df105 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -58,10 +58,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { // set data auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float); - trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); if (trans_weight_ == nullptr) { - MS_LOG(ERROR) << "malloc matrix_buffer failed."; - return RET_MEMORY_FAILED; + trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); + if (trans_weight_ == nullptr) { + MS_LOG(ERROR) << "malloc matrix_buffer failed."; + return RET_MEMORY_FAILED; + } } memset(trans_weight_, 0, trans_matrix_data_size); @@ -217,6 +219,9 @@ int ConvolutionWinogradCPUKernel::Run() { FreeTmpBuffer(); return RET_ERROR; } + if (IsTrain()) { + InitWeightBias(); + } ret = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradImpl, this, thread_count_); if (ret != RET_OK) { @@ -226,4 +231,11 @@ int ConvolutionWinogradCPUKernel::Run() { FreeTmpBuffer(); return ret; } + +int ConvolutionWinogradCPUKernel::Eval() { + LiteKernel::Eval(); + InitWeightBias(); + return RET_OK; +} + } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index 6e9c26efaa..f9f2edb3c3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ 
-14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ #include #include "src/lite_kernel.h" @@ -43,6 +43,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { int Init() override; int ReSize() override; int Run() override; + int Eval() override; int RunImpl(int task_id); int InitWeightBias(); int InitTmpBuffer(); @@ -84,4 +85,4 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { }; } // namespace mindspore::kernel -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_H_ +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_WINOGRAD_FP32_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc index 35c3e2c057..4aea7ab019 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.cc @@ -48,33 +48,27 @@ int ActivationGradCPUKernel::DoActivation(int task_id) { auto output_addr = reinterpret_cast(out_tensors_.at(0)->MutableData()); int length = in_tensors_.at(0)->ElementsNum(); - int stride = UP_DIV(length, 1); + int stride = UP_DIV(length, thread_count_); int count = MSMIN(stride, length - stride * task_id); + size_t start = stride * task_id; auto error_code = RET_OK; if (param_act_grad_->type_ == schema::ActivationType_RELU) { - error_code = - ReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = ReluGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_RELU6) 
{ - error_code = - Relu6Grad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = Relu6Grad(yt_addr + start, input_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_LEAKY_RELU) { - error_code = LReluGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, - output_addr + stride * task_id, param_act_grad_->alpha_); + error_code = LReluGrad(yt_addr + start, input_addr + start, count, output_addr + start, param_act_grad_->alpha_); } else if (param_act_grad_->type_ == schema::ActivationType_SIGMOID) { // Sigmoid gets the input tensors in reverse order! - error_code = - SigmoidGrad(input_addr + stride * task_id, yt_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = SigmoidGrad(input_addr + start, yt_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_TANH) { - error_code = - TanhGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = TanhGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_HSWISH) { - error_code = - HSwishGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = HSwishGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else if (param_act_grad_->type_ == schema::ActivationType_HSIGMOID) { - error_code = - HSigmoidGrad(yt_addr + stride * task_id, input_addr + stride * task_id, count, output_addr + stride * task_id); + error_code = HSigmoidGrad(yt_addr + start, input_addr + start, count, output_addr + start); } else { MS_LOG(ERROR) << "Activation type error"; return RET_ERROR; @@ -97,7 +91,7 @@ int ActivationGradRun(void *cdata, int task_id) { } int ActivationGradCPUKernel::Run() { - int error_code = 
ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, ActivationGradRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Activation Grad function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h index f56b9ec9cc..50c3ddb5f7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/activation_grad.h @@ -27,7 +27,7 @@ class ActivationGradCPUKernel : public LiteKernel { explicit ActivationGradCPUKernel(OpParameter *param, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(param, inputs, outputs, ctx, primitive) { + : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { param_act_grad_ = reinterpret_cast(param); } ~ActivationGradCPUKernel() override = default; @@ -39,6 +39,7 @@ class ActivationGradCPUKernel : public LiteKernel { private: ActivationParameter *param_act_grad_; + int thread_count_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc index 9dac7deaa0..8691b7e254 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.cc @@ -33,17 +33,23 @@ namespace mindspore::kernel { int AdamCPUKernel::ReSize() { return RET_OK; } int AdamCPUKernel::Execute(int task_id) { - auto weight = reinterpret_cast(in_tensors_[0]->MutableData()); - auto m = reinterpret_cast(in_tensors_[1]->MutableData()); - auto v = reinterpret_cast(in_tensors_[2]->MutableData()); - auto beta1_power = reinterpret_cast(in_tensors_[3]->MutableData())[0]; - auto 
beta2_power = reinterpret_cast(in_tensors_[4]->MutableData())[0]; - auto learning_rate = reinterpret_cast(in_tensors_[5]->MutableData())[0]; - auto beta1 = reinterpret_cast(in_tensors_[6]->MutableData())[0]; - auto beta2 = reinterpret_cast(in_tensors_[7]->MutableData())[0]; - auto eps = reinterpret_cast(in_tensors_[8]->MutableData())[0]; - auto gradient = reinterpret_cast(in_tensors_[9]->MutableData()); - size_t elem_num = in_tensors_[0]->ElementsNum(); + auto weight = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto m = reinterpret_cast(in_tensors_.at(1)->MutableData()); + auto v = reinterpret_cast(in_tensors_.at(2)->MutableData()); + auto beta1_power = reinterpret_cast(in_tensors_.at(3)->MutableData())[0]; + auto beta2_power = reinterpret_cast(in_tensors_.at(4)->MutableData())[0]; + auto learning_rate = reinterpret_cast(in_tensors_.at(5)->MutableData())[0]; + auto beta1 = reinterpret_cast(in_tensors_.at(6)->MutableData())[0]; + auto beta2 = reinterpret_cast(in_tensors_.at(7)->MutableData())[0]; + auto eps = reinterpret_cast(in_tensors_.at(8)->MutableData())[0]; + auto gradient = reinterpret_cast(in_tensors_.at(9)->MutableData()); + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; if ((1.f - beta1_power) <= 0.0f) { MS_LOG(ERROR) << "divisor cannot be 0 or below"; @@ -55,17 +61,19 @@ int AdamCPUKernel::Execute(int task_id) { } auto update_lr = learning_rate * std::sqrt(1.f - beta2_power) / (1.f - beta1_power); + const float one_minus_beta1 = 1.f - beta1; + const float one_minus_beta2 = 1.f - beta2; if (adam_param_->use_nesterov_) { // Nadam - for (size_t i = 0; i < elem_num; ++i) { - m[i] += (gradient[i] - m[i]) * (1.f - beta1); - v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); - weight[i] -= update_lr * (m[i] * beta1 + (1.f - beta1) * gradient[i]) / (std::sqrt(v[i]) + 
eps); + for (size_t i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; + weight[i] -= update_lr * (m[i] * beta1 + one_minus_beta1 * gradient[i]) / (std::sqrt(v[i]) + eps); } } else { - for (size_t i = 0; i < elem_num; ++i) { - m[i] += (gradient[i] - m[i]) * (1.f - beta1); - v[i] += (gradient[i] * gradient[i] - v[i]) * (1.f - beta2); + for (size_t i = start; i < end; ++i) { + m[i] += (gradient[i] - m[i]) * one_minus_beta1; + v[i] += (gradient[i] * gradient[i] - v[i]) * one_minus_beta2; weight[i] -= update_lr * m[i] / (std::sqrt(v[i]) + eps); } } @@ -84,7 +92,7 @@ int AdamRun(void *cdata, int task_id) { } int AdamCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, AdamRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Adam function error error_code[" << error_code << "]"; return RET_ERROR; @@ -92,6 +100,17 @@ int AdamCPUKernel::Run() { return RET_OK; } +int AdamCPUKernel::SetLearningRate(float lr) { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(5)->MutableData()); + learning_rate_tensor[0] = lr; + return RET_OK; +} + +float AdamCPUKernel::GetLearningRate() { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(5)->MutableData()); + return learning_rate_tensor[0]; +} + int AdamCPUKernel::Init() { return RET_OK; } kernel::LiteKernel *CpuAdamFp32KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h index 66a387c5cf..a891c97ed2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/adam.h @@ -18,25 +18,28 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_ADAM_H_ #include -#include "src/lite_kernel.h" +#include "src/train/optimizer_kernel.h" 
#include "nnacl/fp32_grad/optimizer.h" namespace mindspore::kernel { -class AdamCPUKernel : public LiteKernel { +class AdamCPUKernel : public OptimizerKernel { public: explicit AdamCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { adam_param_ = reinterpret_cast(parameter); } ~AdamCPUKernel() override {} int Init() override; int ReSize() override; int Run() override; + int SetLearningRate(float lr) override; + float GetLearningRate() override; int Execute(int task_id); private: + int thread_count_; AdamParameter *adam_param_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc index 6213df7697..8c6ac3046b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.cc @@ -31,20 +31,26 @@ namespace mindspore::kernel { int ApplyMomentumCPUKernel::ReSize() { return RET_OK; } int ApplyMomentumCPUKernel::Execute(int task_id) { - auto weight = reinterpret_cast(in_tensors_[0]->MutableData()); - auto accumulate = reinterpret_cast(in_tensors_[1]->MutableData()); - float learning_rate = reinterpret_cast(in_tensors_[2]->MutableData())[0]; - auto gradient = reinterpret_cast(in_tensors_[3]->MutableData()); - float moment = reinterpret_cast(in_tensors_[4]->MutableData())[0]; - size_t elem_num = in_tensors_[0]->ElementsNum(); + auto weight = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto accumulate = reinterpret_cast(in_tensors_.at(1)->MutableData()); + float learning_rate = reinterpret_cast(in_tensors_.at(2)->MutableData())[0]; + auto gradient = 
reinterpret_cast(in_tensors_.at(3)->MutableData()); + float moment = reinterpret_cast(in_tensors_.at(4)->MutableData())[0]; + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; if (apply_momentum_param_->use_nesterov_) { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { accumulate[i] = accumulate[i] * moment + gradient[i]; weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; } } else { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { accumulate[i] = accumulate[i] * moment + gradient[i]; weight[i] -= accumulate[i] * learning_rate; } @@ -64,7 +70,7 @@ int ApplyMomentumRun(void *cdata, int task_id) { } int ApplyMomentumCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, ApplyMomentumRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Apply Momentum function error error_code[" << error_code << "]"; return RET_ERROR; @@ -74,6 +80,17 @@ int ApplyMomentumCPUKernel::Run() { int ApplyMomentumCPUKernel::Init() { return RET_OK; } +int ApplyMomentumCPUKernel::SetLearningRate(float lr) { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(2)->MutableData()); + learning_rate_tensor[0] = lr; + return RET_OK; +} + +float ApplyMomentumCPUKernel::GetLearningRate() { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(2)->MutableData()); + return learning_rate_tensor[0]; +} + kernel::LiteKernel *CpuApplyMomentumFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h 
b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h index bef6154d4c..95a4c0c73a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/apply_momentum.h @@ -18,16 +18,18 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_APPLY_MOMENTUM_H_ #include -#include "src/lite_kernel.h" +#include "src/train/optimizer_kernel.h" #include "nnacl/fp32_grad/optimizer.h" namespace mindspore::kernel { -class ApplyMomentumCPUKernel : public LiteKernel { +class ApplyMomentumCPUKernel : public OptimizerKernel { public: explicit ApplyMomentumCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), apply_momentum_param_(nullptr) { + : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), + thread_count_(ctx->thread_num_), + apply_momentum_param_(nullptr) { apply_momentum_param_ = reinterpret_cast(parameter); } ~ApplyMomentumCPUKernel() override {} @@ -35,8 +37,11 @@ class ApplyMomentumCPUKernel : public LiteKernel { int ReSize() override; int Run() override; int Execute(int task_id); + int SetLearningRate(float lr) override; + float GetLearningRate() override; private: + int thread_count_; ApplyMomentumParameter *apply_momentum_param_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc index 479e56877d..56eed20633 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.cc @@ -49,27 +49,24 @@ int ArithmeticSelfGradCPUKernel::Init() { return RET_OK; } -int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int thread_id) { - auto dy = 
reinterpret_cast(in_tensors_[0]->MutableData()); - auto in_x = reinterpret_cast(in_tensors_[1]->MutableData()); - auto dx = reinterpret_cast(out_tensors_[0]->MutableData()); - int dy_size = in_tensors_.at(0)->ElementsNum(); - int size = MSMIN(thread_stride_, static_cast(dy_size - thread_id * thread_stride_)); - if (size <= 0) { - return RET_OK; - } - int offset = thread_id * thread_stride_; - (*self_grad_operation_)(dy + offset, in_x + offset, dx + offset, size); +int ArithmeticSelfGradCPUKernel::DoArithmeticSelfGrad(int task_id) { + auto dy = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto in_x = reinterpret_cast(in_tensors_.at(1)->MutableData()); + auto dx = reinterpret_cast(out_tensors_.at(0)->MutableData()); + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + size_t start = stride * task_id; + + (*self_grad_operation_)(dy + start, in_x + start, dx + start, count); return RET_OK; } int ArithmeticSelfGradCPUKernel::ReSize() { return RET_OK; } int ArithmeticSelfGradCPUKernel::Run() { - int dy_size = in_tensors_.at(0)->ElementsNum(); - op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast(dy_size)); - thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_); - auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, op_parameter_->thread_num_); + auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticSelfGradRun, this, thread_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; return ret; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h index 37a2995dc7..28b90c9045 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/arithmetic_self_grad.h @@ -30,7 +30,7 @@ 
class ArithmeticSelfGradCPUKernel : public LiteKernel { ArithmeticSelfGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} ~ArithmeticSelfGradCPUKernel() override {} int Init() override; int ReSize() override; @@ -38,7 +38,7 @@ class ArithmeticSelfGradCPUKernel : public LiteKernel { int DoArithmeticSelfGrad(int thread_id); private: - int thread_stride_; + int thread_count_; ArithmeticSelfGradOperation self_grad_operation_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc index df4b6a686a..75d83def8c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.cc @@ -32,11 +32,16 @@ namespace mindspore::kernel { int AssignCPUKernel::ReSize() { return RET_OK; } int AssignCPUKernel::Execute(int task_id) { - auto x = reinterpret_cast(in_tensors_[0]->MutableData()); - auto y = reinterpret_cast(in_tensors_[1]->MutableData()); - size_t size = in_tensors_[0]->Size(); + auto x = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto y = reinterpret_cast(in_tensors_.at(1)->MutableData()); + size_t length = in_tensors_.at(0)->ElementsNum(); - memcpy(x, y, size); + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + + memcpy(&(x[start]), &(y[start]), count * sizeof(float)); return RET_OK; } @@ -52,7 +57,7 @@ int AssignRun(void *cdata, int task_id) { } int AssignCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, AssignRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, 
AssignRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Assign function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h index dd2575e62a..0da097de21 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/assign.h @@ -27,12 +27,15 @@ class AssignCPUKernel : public LiteKernel { explicit AssignCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} ~AssignCPUKernel() override {} int Init() override; int ReSize() override; int Run() override; int Execute(int task_id); + + protected: + int thread_count_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc index f6b4e86269..5635d90714 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bias_grad.cc @@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_BiasGrad; namespace mindspore::kernel { -int BiasGradCPUKernel::Init() { +int BiasGradCPUKernel::ReSize() { auto dims = in_tensors_[0]->shape(); bias_param->ndim_ = dims.size(); for (unsigned int i = 0; i < bias_param->ndim_; i++) { @@ -44,7 +44,12 @@ int BiasGradCPUKernel::Init() { return RET_OK; } -int BiasGradCPUKernel::ReSize() { return RET_OK; } +int BiasGradCPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} int BiasGradCPUKernel::Execute(int task_id) { auto in = reinterpret_cast(in_tensors_.at(0)->MutableData()); diff --git 
a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc index aa614ce17f..90a3c8d72e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/bn_grad.cc @@ -31,17 +31,16 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_BNGrad; namespace mindspore::kernel { -int BNGradCPUKernel::Init() { +int BNGradCPUKernel::ReSize() { auto *input_x = in_tensors_.at(1); int channels = input_x->shape().at(kNHWC_C); set_workspace_size(2 * channels * sizeof(float)); return RET_OK; } -int BNGradCPUKernel::ReSize() { return RET_OK; } +int BNGradCPUKernel::Init() { return ReSize(); } int BNGradCPUKernel::Execute(int task_id) { - auto bn_param = reinterpret_cast(op_parameter_); auto *input_yt = in_tensors_.at(0); auto *input_x = in_tensors_.at(1); auto *input_scale = in_tensors_.at(2); @@ -54,10 +53,9 @@ int BNGradCPUKernel::Execute(int task_id) { auto *output_dx = out_tensors_.at(0); auto *output_scale = out_tensors_.at(1); auto *output_bias = out_tensors_.at(2); - size_t batch = input_x->Batch(); - size_t channels = input_x->Channel(); - size_t spatial = input_x->Height() * input_x->Width(); - float eps = bn_param->epsilon_; + int32_t batch = input_x->Batch(); + int32_t channels = input_x->Channel(); + int32_t spatial = input_x->Height() * input_x->Width(); float *workspace_temp = static_cast(workspace()); std::fill(workspace_temp, workspace_temp + workspace_size() / sizeof(*workspace_temp), 0.f); @@ -68,34 +66,32 @@ int BNGradCPUKernel::Execute(int task_id) { float *yt = reinterpret_cast(input_yt->MutableData()); float *scale = reinterpret_cast(input_scale->MutableData()); float *dx = reinterpret_cast(output_dx->MutableData()); - float *dscale = reinterpret_cast(output_scale->MutableData()); float *dbias = reinterpret_cast(output_bias->MutableData()); - - var2Invar(save_var, input_var->ElementsNum(), eps); - // dx - backwardX(x, 
yt, scale, batch * spatial, channels, save_mean, save_var, dxhat_sum, dxhathat_sum, dx); - // dbias - sumSpatialBatch(yt, batch * spatial, channels, dbias); - // dscale - backwardScale(x, save_mean, save_var, yt, batch, channels, spatial, dscale); - + float *dscale = reinterpret_cast(output_scale->MutableData()); + std::fill(dbias, dbias + channels, 0.f); + std::fill(dscale, dscale + channels, 0.f); + backwardAll(x, yt, save_mean, save_var, scale, batch * spatial, channels, dxhat_sum, dxhathat_sum, dbias, dscale, dx); return RET_OK; } int BNGradRun(void *cdata, int task_id) { MS_ASSERT(cdata != nullptr); auto bn_kernel = reinterpret_cast(cdata); - if (task_id == 0) { - auto error_code = bn_kernel->Execute(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } + + auto error_code = bn_kernel->Execute(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "BNGradRun error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; } return RET_OK; } int BNGradCPUKernel::Run() { + auto *input_var = in_tensors_.at(4); + float *save_var = reinterpret_cast(input_var->MutableData()); + auto bn_param = reinterpret_cast(op_parameter_); + float eps = bn_param->epsilon_; + var2Invar(save_var, input_var->ElementsNum(), eps); int error_code = ParallelLaunch(this->context_->thread_pool_, BNGradRun, this, 1); if (error_code != RET_OK) { MS_LOG(ERROR) << "BN function error error_code[" << error_code << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc index e109b221ed..24beda510f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.cc @@ -26,7 +26,7 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_OK; namespace mindspore::kernel { -int 
ConvolutionTrainCPUKernel::Init() { +int ConvolutionTrainCPUKernel::ReSize() { if (in_tensors_.size() < 2) { MS_LOG(ERROR) << "Convolution should have at least two inputs"; return RET_ERROR; @@ -54,13 +54,21 @@ int ConvolutionTrainCPUKernel::Init() { conv_param_->group_ = (conv_param_->group_ == 0) ? conv_param_->input_channel_ : conv_param_->group_; const int n = conv_param_->output_channel_ * conv_param_->group_; const int k = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ / conv_param_->group_; - ws_size = chunk * k; - int mat_alloc = MatSizeTotal(chunk, n, k, 0); - set_workspace_size((ws_size + mat_alloc) * sizeof(float)); + ws_size_ = chunk_ * k; + int mat_alloc = MatSizeTotal(chunk_, n, k, 0); + set_workspace_size((ws_size_ + mat_alloc) * sizeof(float)); + + do_img2col_ = (conv_param_->kernel_h_ == 1) && (conv_param_->kernel_w_ == 1) && (conv_param_->pad_d_ == 0) && + (conv_param_->pad_u_ == 0) && (conv_param_->pad_l_ == 0) && (conv_param_->pad_r_ == 0) && + (conv_param_->dilation_h_ == 1) && (conv_param_->dilation_w_ == 1) && + (conv_param_->stride_h_ == 1) && (conv_param_->stride_w_ == 1) && (conv_param_->group_ == 1) + ? 
false + : true; + return RET_OK; } -int ConvolutionTrainCPUKernel::ReSize() { return RET_OK; } +int ConvolutionTrainCPUKernel::Init() { return ReSize(); } int ConvolutionTrainCPUKernel::Execute(int task_id) { auto conv_param_ = reinterpret_cast(op_parameter_); @@ -87,17 +95,34 @@ int ConvolutionTrainCPUKernel::Execute(int task_id) { const int n = out_ch / groups; const int k = k_h * k_w * in_ch / groups; float *workspace_temp = static_cast(workspace()); - float *mat_workspace = workspace_temp + ws_size; - for (int i = 0; i < batch; ++i) { - for (int j = 0; j < groups; ++j) { - for (int ci = 0; ci < m; ci += chunk) { - int real_chunk = MSMIN(m - ci, chunk); - float *mat_a = workspace_temp; - const float *mat_b = w_addr + j * nweights / groups; - float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; - float *im = x_addr + (i * groups) * (in_ch / groups) * in_h * in_w + j * (in_ch / groups); - RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci); - GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); + float *mat_workspace = workspace_temp + ws_size_; + + if (do_img2col_) { + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < groups; ++j) { + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, chunk_); + float *mat_a = workspace_temp; + const float *mat_b = w_addr + j * nweights / groups; + float *mat_c = y_addr + (i * groups) * n * m + j * (out_ch / groups) + ci * out_ch; + float *im = x_addr + i * in_ch * in_h * in_w + j * (in_ch / groups); + RollingIm2ColPackUnitFp32(im, conv_param_, mat_a, real_chunk, ci); + GemmMatmul(0, 1, real_chunk, n, k, 1, mat_a, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); + } + } + } + } else { + const float *mat_b = w_addr; + const size_t in_plane_size = in_ch * in_h * in_w; + for (int i = 0; i < batch; ++i) { + float *im = x_addr + i * in_plane_size; + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, 
chunk_); + float *mat_c = y_addr + i * n * m + ci * out_ch; + int input_height = ci / out_w * conv_param_->stride_h_; + int input_width = ci % out_w * conv_param_->stride_w_; + int offset = (input_height * in_w + input_width) * in_ch; + GemmMatmul(0, 1, real_chunk, n, k, 1, im + offset, k, mat_b, k, 0, mat_c, out_ch, mat_workspace); } } } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h index dd212e7f87..bfaf9d25ea 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution.h @@ -35,11 +35,12 @@ class ConvolutionTrainCPUKernel : public LiteKernel { int Execute(int task_id); private: - int ws_size = 0; + int ws_size_ = 0; + bool do_img2col_ = true; #ifdef ENABLE_ARM32 - const int chunk = C4NUM; + const int chunk_ = C4NUM * 2; #else - const int chunk = C12NUM; + const int chunk_ = C12NUM * 2; #endif }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc index 432f9e5eaa..d872cceb9f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.cc @@ -29,7 +29,7 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_Conv2DGradFilter; namespace mindspore::kernel { -int ConvolutionGradFilterCPUKernel::Init() { +int ConvolutionGradFilterCPUKernel::ReSize() { // dy is in input 0 // x is in input 1 // dw is output 0 @@ -51,16 +51,25 @@ int ConvolutionGradFilterCPUKernel::Init() { conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; - ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; + ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * 
conv_param->input_channel_ / conv_param->group_; int n = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; int k = conv_param->output_channel_ / conv_param->group_; - size_t mat_alloc = MatSizeTotal(k, n, chunk, n); - set_workspace_size((ws_size + mat_alloc) * sizeof(float)); + int thread_num = context_->thread_num_; + mat_alloc_ = MatSizeTotal(k, n, chunk_, 0); + set_workspace_size((ws_size_ + mat_alloc_ + (k * n)) * thread_num * sizeof(float)); + + do_img2col_ = (conv_param->kernel_h_ == 1) && (conv_param->kernel_w_ == 1) && (conv_param->pad_d_ == 0) && + (conv_param->pad_u_ == 0) && (conv_param->pad_l_ == 0) && (conv_param->pad_r_ == 0) && + (conv_param->dilation_h_ == 1) && (conv_param->dilation_w_ == 1) && (conv_param->stride_h_ == 1) && + (conv_param->stride_w_ == 1) && (conv_param->group_ == 1) + ? false + : true; + return RET_OK; } -int ConvolutionGradFilterCPUKernel::ReSize() { return RET_OK; } +int ConvolutionGradFilterCPUKernel::Init() { return ReSize(); } int ConvolutionGradFilterCPUKernel::Execute(int task_id) { auto conv_param = reinterpret_cast(op_parameter_); @@ -72,7 +81,6 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { auto dy_addr = reinterpret_cast(input_dy->MutableData()); auto dw_addr = reinterpret_cast(out_dw->MutableData()); - int i, j; int nweights = out_dw->ElementsNum(); int in_ch = conv_param->input_channel_; int in_h = conv_param->input_h_; @@ -88,22 +96,45 @@ int ConvolutionGradFilterCPUKernel::Execute(int task_id) { int m = out_h * out_w; int n = k_h * k_w * in_ch / groups; int k = out_ch / groups; - + int thread_num = context_->thread_num_; float *workspace_temp = reinterpret_cast(workspace()); - float *mat_workspace = workspace_temp + ws_size; - // zero out pointer - memset(dw_addr, 0, out_dw->Size()); - for (i = 0; i < batch; ++i) { - for (j = 0; j < groups; ++j) { - for (int ci = 0; ci < m; ci += chunk) { - int real_chunk = MSMIN(m - ci, chunk); - float *mat_a = dy_addr 
+ (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; - float *mat_b = workspace_temp; - float *mat_c = dw_addr + j * nweights / groups; - float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); - memset(mat_b, 0, n * real_chunk * sizeof(float)); - RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci); - GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 1, mat_c, n, mat_workspace); + float *mat_workspace = workspace_temp + ws_size_ * thread_num + task_id * (mat_alloc_ + k * n); + float *mat_tmp = mat_workspace + mat_alloc_; + int stride = UP_DIV(batch, thread_num); + int count = MSMIN(stride, batch - stride * task_id); + int start = stride * task_id; + int end = start + count; + + if (do_img2col_) { + for (int i = start; i < end; ++i) { + for (int j = 0; j < groups; ++j) { + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, chunk_); + float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; + float *mat_b = workspace_temp + task_id * ws_size_; + float *mat_c = dw_addr + j * nweights / groups; + float *im = x_addr + (i * in_ch * in_h * in_w) + j * (in_ch / groups); + RollingIm2ColPackUnitFp32(im, conv_param, mat_b, real_chunk, ci); + GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, mat_b, n, 0, mat_tmp, n, mat_workspace); + std::unique_lock merge_lock(lock_); + AddMatrix(mat_tmp, mat_c, 1, k, n, n); + } + } + } + } else { + float *mat_c = dw_addr; + const size_t in_plane_size = in_ch * in_h * in_w; + for (int i = start; i < end; ++i) { + for (int ci = 0; ci < m; ci += chunk_) { + int real_chunk = MSMIN(m - ci, chunk_); + float *mat_a = dy_addr + i * m * k + ci * out_ch; + float *im = x_addr + i * in_plane_size; + int input_h = ci / out_w * conv_param->stride_h_; + int input_w = ci % out_w * conv_param->stride_w_; + int offset = (input_h * in_w + input_w) * in_ch; + GemmMatmul(1, 0, k, n, real_chunk, 1, mat_a, out_ch, im + offset, n, 0, mat_tmp, n, mat_workspace); + 
std::unique_lock merge_lock(lock_); + AddMatrix(mat_tmp, mat_c, 1, k, n, n); } } } @@ -122,7 +153,10 @@ int ConvolutionGradFilterRun(void *cdata, int task_id) { } int ConvolutionGradFilterCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, 1); + auto *out_dw = out_tensors_.at(0); + auto dw_addr = reinterpret_cast(out_dw->MutableData()); + memset(dw_addr, 0, out_dw->Size()); + int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradFilterRun, this, context_->thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "conv filter function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h index 465a91e57a..bf1696ab65 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_filter.h @@ -36,11 +36,14 @@ class ConvolutionGradFilterCPUKernel : public LiteKernel { int Execute(int task_id); private: - size_t ws_size = 0; + size_t ws_size_ = 0; + bool do_img2col_ = true; + std::mutex lock_; + size_t mat_alloc_ = 0; #ifdef ENABLE_ARM32 - const int chunk = C4NUM; + const int chunk_ = C4NUM * 2; #else - const int chunk = C12NUM; + const int chunk_ = C12NUM * 2; #endif }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc index 0ad2d7c459..4276961202 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.cc @@ -30,7 +30,7 @@ using mindspore::schema::PrimitiveType_Conv2DGradInput; using mindspore::schema::PrimitiveType_GroupConv2DGradInput; namespace mindspore::kernel { -int 
ConvolutionGradInputCPUKernel::Init() { +int ConvolutionGradInputCPUKernel::ReSize() { auto *dy_tensor = in_tensors_.at(kInputIndex); MS_ASSERT(dy_tensor != nullptr); auto *weight_tensor = in_tensors_.at(kWeightIndex); @@ -51,18 +51,17 @@ int ConvolutionGradInputCPUKernel::Init() { conv_param->output_h_ = dy_tensor->shape()[kNHWC_H]; conv_param->output_w_ = dy_tensor->shape()[kNHWC_W]; - ws_size = chunk * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; + ws_size_ = chunk_ * conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_ / conv_param->group_; int n = conv_param->kernel_w_ * conv_param->kernel_h_ * conv_param->input_channel_ / conv_param->group_; int k = conv_param->output_channel_ / conv_param->group_; - - size_t mat_alloc = MatSizeTotal(chunk, n, k, 0); - - set_workspace_size((ws_size + mat_alloc) * sizeof(float)); + int thread_num = context_->thread_num_; + mat_alloc_ = MatSizeTotal(chunk_, n, k, 0); + set_workspace_size((ws_size_ + mat_alloc_) * sizeof(float) * thread_num); return RET_OK; } -int ConvolutionGradInputCPUKernel::ReSize() { return RET_OK; } +int ConvolutionGradInputCPUKernel::Init() { return ReSize(); } int ConvolutionGradInputCPUKernel::Execute(int task_id) { auto conv_param = reinterpret_cast(op_parameter_); @@ -86,17 +85,21 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) { int groups = conv_param->group_; int out_h = conv_param->output_h_; int out_w = conv_param->output_w_; - + int thread_num = context_->thread_num_; int m = out_h * out_w; int n = k_w * k_h * in_ch / groups; int k = out_ch / groups; - float *workspace_temp = reinterpret_cast(workspace()); - float *mat_workspace = workspace_temp + ws_size; - memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); - for (i = 0; i < batch; ++i) { + float *workspace_temp = reinterpret_cast(workspace()) + task_id * (mat_alloc_ + ws_size_); + float *mat_workspace = workspace_temp + ws_size_; + int stride = 
UP_DIV(batch, thread_num); + int count = MSMIN(stride, batch - stride * task_id); + int start = stride * task_id; + int end = start + count; + + for (i = start; i < end; ++i) { for (j = 0; j < groups; ++j) { GemmCb gcb; - for (int ci = 0; ci < m; ci += chunk) { + for (int ci = 0; ci < m; ci += chunk_) { float *mat_b = nullptr; if (ci == 0) { mat_b = w_addr + j * nweights / groups; @@ -108,7 +111,7 @@ int ConvolutionGradInputCPUKernel::Execute(int task_id) { mat_b = gcb.mat_b; gcb.cb = 1; } - int real_chunk = MSMIN(m - ci, chunk); + int real_chunk = MSMIN(m - ci, chunk_); float *mat_a = dy_addr + (i * groups) * m * k + j * (out_ch / groups) + ci * out_ch; float *mat_c = workspace_temp; GemmMatmulPlus(0, 0, real_chunk, n, k, 1, mat_a, out_ch, mat_b, n, 0, mat_c, n, mat_workspace, &gcb); @@ -133,7 +136,15 @@ int ConvolutionGradInputRun(void *cdata, int task_id) { } int ConvolutionGradInputCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, 1); + auto conv_param = reinterpret_cast(op_parameter_); + int batch = conv_param->output_batch_; + int in_ch = conv_param->input_channel_; + int in_h = conv_param->input_h_; + int in_w = conv_param->input_w_; + auto *out_dx = out_tensors_.at(0); + auto dx_addr = reinterpret_cast(out_dx->MutableData()); + memset(dx_addr, 0, sizeof(float) * batch * in_ch * in_h * in_w); + int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionGradInputRun, this, context_->thread_num_); if (error_code != RET_OK) { MS_LOG(ERROR) << "bias function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h index d4b226dd9b..4578992d0b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/convolution_grad_input.h @@ -35,11 +35,12 @@ 
class ConvolutionGradInputCPUKernel : public LiteKernel { int Execute(int task_id); private: - size_t ws_size = 0; + size_t ws_size_ = 0; + size_t mat_alloc_ = 0; #ifdef ENABLE_ARM32 - const int chunk = C4NUM; + const int chunk_ = C4NUM; #else - const int chunk = C12NUM; + const int chunk_ = C12NUM; #endif }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc index 7fa2eafa8b..374b182af7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.cc @@ -61,19 +61,26 @@ int DropoutCPUKernel::Execute(int task_id) { auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); auto mask = reinterpret_cast(out_tensors_.at(1)->MutableData()); - auto length = in_tensors_.at(kInputIndex)->ElementsNum(); auto param = reinterpret_cast(op_parameter_); + auto length = in_tensors_.at(kInputIndex)->ElementsNum(); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; + if (param == nullptr) { MS_LOG(ERROR) << "Dropout op_parameter_ nullptr"; return RET_NULL_PTR; } if (IsEval()) { - std::copy(input_ptr, input_ptr + length, output_ptr); + std::copy(&(input_ptr[start]), &(input_ptr[end]), &(output_ptr[start])); } else { std::default_random_engine generator; std::bernoulli_distribution distribution(param->ratio_); - for (int i = 0; i < length; i++) { + for (size_t i = start; i < end; i++) { mask[i] = distribution(generator); output_ptr[i] = input_ptr[i] * mask[i] * scale_; } @@ -92,7 +99,7 @@ int RunDropout(void *cdata, int task_id) { } int DropoutCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropout, this, 1); + int error_code = 
ParallelLaunch(this->context_->thread_pool_, RunDropout, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Dropout function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.h index 2b4093a73b..dbfe0252bf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout.h @@ -25,7 +25,7 @@ class DropoutCPUKernel : public LiteKernel { DropoutCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} ~DropoutCPUKernel() override = default; @@ -35,7 +35,8 @@ class DropoutCPUKernel : public LiteKernel { int Execute(int task_id); private: - float scale_; + float scale_ = 1.0; + int thread_count_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc index bb62ba40f8..544fe81433 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.cc @@ -62,7 +62,13 @@ int DropoutGradCPUKernel::Execute(int task_id) { auto mask_ptr = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); auto length = in_tensors_.at(kInputIndex)->ElementsNum(); - DropoutGrad(yt_ptr, mask_ptr, output_ptr, length, scale_); + + int stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + + DropoutGrad(&(yt_ptr[start]), &(mask_ptr[start]), &(output_ptr[start]), 
count, scale_); return RET_OK; } @@ -78,7 +84,7 @@ int RunDropoutGrad(void *cdata, int task_id) { } int DropoutGradCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, RunDropoutGrad, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Dropout Grad function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h index 1740656dde..11e79a53f4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/dropout_grad.h @@ -25,7 +25,7 @@ class DropoutGradCPUKernel : public LiteKernel { DropoutGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} ~DropoutGradCPUKernel() override = default; @@ -36,6 +36,7 @@ class DropoutGradCPUKernel : public LiteKernel { private: float scale_; + int thread_count_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc index 0b6b17fc33..826cd01156 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.cc @@ -29,36 +29,34 @@ using mindspore::schema::PrimitiveType_NegGrad; namespace mindspore::kernel { namespace { -int NegGradRun(void *cdata, int thread_id) { +int NegGradRun(void *cdata, int task_id) { MS_ASSERT(cdata != nullptr); auto kernel = reinterpret_cast(cdata); MS_ASSERT(kernel != nullptr); - return 
kernel->DoNegGrad(thread_id); + return kernel->DoNegGrad(task_id); } } // namespace int NegGradCPUKernel::Init() { return RET_OK; } -int NegGradCPUKernel::DoNegGrad(int thread_id) { - auto dy = reinterpret_cast(in_tensors_[0]->MutableData()); - auto dx = reinterpret_cast(out_tensors_[0]->MutableData()); - int dy_size = in_tensors_.at(0)->ElementsNum(); - int size = MSMIN(thread_stride_, static_cast(dy_size - thread_id * thread_stride_)); - if (size <= 0) { - return RET_OK; - } - int offset = thread_id * thread_stride_; - ElementNegative(dy + offset, dx + offset, size); +int NegGradCPUKernel::DoNegGrad(int task_id) { + auto dy = reinterpret_cast(in_tensors_.at(0)->MutableData()); + auto dx = reinterpret_cast(out_tensors_.at(0)->MutableData()); + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + + ElementNegative(dy + start, dx + start, count); return RET_OK; } int NegGradCPUKernel::ReSize() { return RET_OK; } int NegGradCPUKernel::Run() { - int dy_size = in_tensors_.at(0)->ElementsNum(); - op_parameter_->thread_num_ = MSMIN(op_parameter_->thread_num_, static_cast(dy_size)); - thread_stride_ = UP_DIV(dy_size, op_parameter_->thread_num_); - auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, op_parameter_->thread_num_); + auto ret = ParallelLaunch(this->context_->thread_pool_, NegGradRun, this, thread_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "parallel launch fail!ret: " << ret; return ret; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.h index fdbda2a18b..2c2f5aad07 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/neg_grad.h @@ -28,7 +28,7 @@ class NegGradCPUKernel : public LiteKernel { explicit NegGradCPUKernel(OpParameter *parameter, 
const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {} ~NegGradCPUKernel() override {} int Init() override; int ReSize() override; @@ -36,7 +36,7 @@ class NegGradCPUKernel : public LiteKernel { int DoNegGrad(int thread_id); private: - int thread_stride_; + int thread_count_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index 32c3a4f24e..13cfbfd6d5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -29,7 +29,7 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_PoolingGrad; namespace mindspore::kernel { -int PoolingGradCPUKernel::Init() { +int PoolingGradCPUKernel::ReSize() { PoolingParameter *pool_param = reinterpret_cast(op_parameter_); auto in_shape = in_tensors_.at(0)->shape(); @@ -59,7 +59,7 @@ int PoolingGradCPUKernel::Init() { return RET_OK; } -int PoolingGradCPUKernel::ReSize() { return RET_OK; } +int PoolingGradCPUKernel::Init() { return ReSize(); } int PoolingGradCPUKernel::Execute(int task_id) { PoolingParameter *pool_param = reinterpret_cast(op_parameter_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc index aacf5ec282..0cb9b59532 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.cc @@ -45,13 +45,20 @@ int PowerGradCPUKernel::Execute(int task_id) { auto dy_addr = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto x_addr = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto dx_addr = 
reinterpret_cast(out_tensors_.at(0)->MutableData()); - auto size = in_tensors_.at(0)->ElementsNum(); + + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; float exp = power_ - 1; - Power(x_addr, &exp, dx_addr, size, scale_, shift_, true); - ElementMul(dx_addr, dy_addr, dx_addr, size); + Power(&(x_addr[start]), &exp, &(dx_addr[start]), count, scale_, shift_, true); + ElementMul(&(dx_addr[start]), &(dy_addr[start]), &(dx_addr[start]), count); float scale = scale_ * power_; - for (int i = 0; i < size; i++) { + for (size_t i = start; i < end; i++) { dx_addr[i] *= scale; } @@ -69,7 +76,7 @@ int PowerGradRun(void *cdata, int task_id) { } int PowerGradCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, PowerGradRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "power grad function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h index 8b1702c53a..4ce3cf3622 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/power_grad.h @@ -27,7 +27,7 @@ class PowerGradCPUKernel : public LiteKernel { PowerGradCPUKernel(OpParameter *param, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(param, inputs, outputs, ctx, primitive) { + : LiteKernel(param, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { PowerParameter *power_param = reinterpret_cast(param); power_ = power_param->power_; scale_ = power_param->scale_; @@ -41,6 +41,7 @@ class 
PowerGradCPUKernel : public LiteKernel { int Execute(int task_id); private: + int thread_count_; float power_; float scale_; float shift_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc index 731b766a50..05436b865b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.cc @@ -16,6 +16,7 @@ */ #include "src/runtime/kernel/arm/fp32_grad/sgd.h" +#include #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" @@ -37,36 +38,42 @@ int SgdCPUKernel::Execute(int task_id) { float learning_rate = reinterpret_cast(in_tensors_.at(2)->MutableData())[0]; auto gradient = reinterpret_cast(in_tensors_.at(1)->MutableData()); float moment = reinterpret_cast(in_tensors_.at(4)->MutableData())[0]; - size_t elem_num = in_tensors_.at(0)->ElementsNum(); auto stat = reinterpret_cast(in_tensors_.at(5)->MutableData()); + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); - if (stat[0] > 0) { - stat[0] = 0; - memcpy(accumulate, gradient, elem_num * sizeof(float)); + size_t start = stride * task_id; + size_t end = start + count; + + if (stat[task_id] > 0) { + stat[task_id] = 0; // Haim Please approve this + std::copy(&(gradient[start]), &(gradient[end]), &(accumulate[start])); if (sgd_param_->use_nesterov_) { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; } } else { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { weight[i] -= accumulate[i] * learning_rate; } } } else { if (moment > 0.f) { if (sgd_param_->use_nesterov_) { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { accumulate[i] = accumulate[i] * moment + gradient[i] * 
(1.f - sgd_param_->dampening_); weight[i] -= (accumulate[i] * moment + gradient[i]) * learning_rate; } } else { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { accumulate[i] = accumulate[i] * moment + gradient[i] * (1.f - sgd_param_->dampening_); weight[i] -= accumulate[i] * learning_rate; } } } else { - for (size_t i = 0; i < elem_num; ++i) { + for (size_t i = start; i < end; ++i) { weight[i] -= gradient[i] * learning_rate; } } @@ -85,7 +92,7 @@ int SgdRun(void *cdata, int task_id) { } int SgdCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, SgdRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "SGD function error error_code[" << error_code << "]"; return RET_ERROR; @@ -114,6 +121,17 @@ int SgdCPUKernel::Init() { return RET_OK; } +int SgdCPUKernel::SetLearningRate(float lr) { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(2)->MutableData()); + learning_rate_tensor[0] = lr; + return RET_OK; +} + +float SgdCPUKernel::GetLearningRate() { + auto learning_rate_tensor = reinterpret_cast(in_tensors_.at(2)->MutableData()); + return learning_rate_tensor[0]; +} + kernel::LiteKernel *CpuSgdFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::InnerContext *ctx, const kernel::KernelKey &desc, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.h index 355d0ed1e2..4f8143aceb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/sgd.h @@ -18,16 +18,18 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_GRAD_SGD_H_ #include -#include "src/lite_kernel.h" +#include "src/train/optimizer_kernel.h" #include "nnacl/fp32_grad/optimizer.h" namespace mindspore::kernel { -class SgdCPUKernel : public LiteKernel { 
+class SgdCPUKernel : public OptimizerKernel { public: explicit SgdCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), sgd_param_(nullptr) { + : OptimizerKernel(parameter, inputs, outputs, ctx, primitive), + thread_count_(ctx->thread_num_), + sgd_param_(nullptr) { sgd_param_ = reinterpret_cast(parameter); } ~SgdCPUKernel() override {} @@ -35,8 +37,11 @@ class SgdCPUKernel : public LiteKernel { int ReSize() override; int Run() override; int Execute(int task_id); + int SetLearningRate(float lr) override; + float GetLearningRate() override; private: + int thread_count_; SgdParameter *sgd_param_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc index d6e1c5cb10..b9b4bc13b8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.cc @@ -35,12 +35,19 @@ int SmoothL1LossCPUKernel::Execute(int task_id) { auto target = reinterpret_cast(in_tensors_.at(1)->MutableData()); auto *out = reinterpret_cast(out_tensors_.at(0)->MutableData()); - const size_t tensor_len = in_tensors_.at(0)->ElementsNum(); + const size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + int count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; + const float zero = 0.0f; const float half = 0.5f; const float beta = smooth_l1_loss_param->beta_; - for (uint64_t i = 0; i < tensor_len; ++i) { + for (uint64_t i = start; i < end; ++i) { float diff = predict[i] - target[i]; if (diff < zero) { diff = -diff; @@ -66,7 +73,7 @@ int SmoothL1LossRun(void *cdata, int task_id) { } int SmoothL1LossCPUKernel::Run() { - int 
error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "SmoothL1Loss function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h index 335e91d2ac..5fd0c0dea2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss.h @@ -27,7 +27,9 @@ class SmoothL1LossCPUKernel : public LiteKernel { explicit SmoothL1LossCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) { + : LiteKernel(parameter, inputs, outputs, ctx, primitive), + smooth_l1_param_(nullptr), + thread_count_(ctx->thread_num_) { smooth_l1_param_ = reinterpret_cast(parameter); } ~SmoothL1LossCPUKernel() override {} @@ -38,6 +40,7 @@ class SmoothL1LossCPUKernel : public LiteKernel { private: SmoothL1LossParameter *smooth_l1_param_; + int thread_count_ = 1; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc index a0685b95b5..4f6f99d418 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.cc @@ -36,10 +36,17 @@ int SmoothL1LossGradCPUKernel::Execute(int task_id) { auto d_loss = reinterpret_cast(in_tensors_.at(2)->MutableData()); auto *out = reinterpret_cast(out_tensors_.at(0)->MutableData()); - const size_t tensor_len = in_tensors_.at(0)->ElementsNum(); + const 
size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t start = stride * task_id; + size_t end = start + count; + const float beta = smooth_l1_loss_param->beta_; - for (uint64_t i = 0; i < tensor_len; ++i) { + for (uint64_t i = start; i < end; ++i) { float diff = predict[i] - target[i]; if (diff > beta) { out[i] = d_loss[i]; @@ -63,7 +70,7 @@ int SmoothL1LossGradRun(void *cdata, int task_id) { } int SmoothL1LossGradCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, SmoothL1LossGradRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "SmoothL1LossGrad function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.h index 9bc049f9d3..e702519e20 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/smooth_l1_loss_grad.h @@ -27,7 +27,9 @@ class SmoothL1LossGradCPUKernel : public LiteKernel { explicit SmoothL1LossGradCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const mindspore::lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive), smooth_l1_param_(nullptr) { + : LiteKernel(parameter, inputs, outputs, ctx, primitive), + smooth_l1_param_(nullptr), + thread_count_(ctx->thread_num_) { smooth_l1_param_ = reinterpret_cast(parameter); } ~SmoothL1LossGradCPUKernel() override {} @@ -38,6 +40,7 @@ class SmoothL1LossGradCPUKernel : public LiteKernel { private: SmoothL1LossParameter *smooth_l1_param_; + int thread_count_; }; } // namespace mindspore::kernel diff --git 
a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc index 7368e7cf05..3dd7cb25b7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/softmax_cross_entropy_with_logits.cc @@ -29,7 +29,7 @@ using mindspore::schema::PrimitiveType_SoftmaxCrossEntropy; namespace mindspore::kernel { -int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; } +int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { return ReSize(); } void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *labels, const float *logits, float *grads, float *output2) const { @@ -100,7 +100,7 @@ int SoftmaxCrossEntropyWithLogitsCPUKernel::Run() { return RET_OK; } -int SoftmaxCrossEntropyWithLogitsCPUKernel::Init() { +int SoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { auto dims = in_tensors_.at(0)->shape(); param_->n_dim_ = 2; param_->number_of_classes_ = dims.at(1); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc index 090f4c714a..55f0717de2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.cc @@ -15,6 +15,7 @@ */ #include +#include #include "src/runtime/kernel/arm/fp32_grad/tuple_getitem.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -47,7 +48,15 @@ int TupleGetItemCPUKernel::Execute(int task_id) { auto in = reinterpret_cast(in_tensors_.at(0)->MutableData()); auto out = reinterpret_cast(out_tensors_.at(0)->MutableData()); - memcpy(out, in, in_tensors_.at(0)->Size()); + size_t length = in_tensors_.at(0)->ElementsNum(); + + size_t stride = UP_DIV(length, thread_count_); + size_t count = MSMIN(stride, length - stride * task_id); + + size_t 
start = stride * task_id; + size_t end = start + count; + + std::copy(&(in[start]), &(in[end]), &(out[start])); return RET_OK; } @@ -62,7 +71,7 @@ int TupleRun(void *cdata, int task_id) { } int TupleGetItemCPUKernel::Run() { - int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, 1); + int error_code = ParallelLaunch(this->context_->thread_pool_, TupleRun, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "tuple function error error_code[" << error_code << "]"; return RET_ERROR; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h index 7bd93fc560..9a7d470b57 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/tuple_getitem.h @@ -27,7 +27,7 @@ class TupleGetItemCPUKernel : public LiteKernel { explicit TupleGetItemCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, const lite::PrimitiveC *primitive) - : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) { param = parameter; } ~TupleGetItemCPUKernel() override = default; @@ -38,6 +38,7 @@ class TupleGetItemCPUKernel : public LiteKernel { int Execute(int task_id); private: + int thread_count_ = 1; OpParameter *param; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/train/classification_train_accuracy_monitor.cc b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc new file mode 100644 index 0000000000..12a1f35421 --- /dev/null +++ b/mindspore/lite/src/train/classification_train_accuracy_monitor.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/train/classification_train_accuracy_monitor.h" +#include +#include +#include +#include +#include +#include +#include +#include "include/errorcode.h" +#include "include/train_session.h" +#include "src/common/utils.h" +#include "src/tensor.h" +#include "src/train/loss_kernel.h" +#include "src/train/optimizer_kernel.h" +#include "src/sub_graph_kernel.h" +#include "src/train/train_populate_parameter.h" +#include "src/runtime/runtime_api.h" +#include "src/executor.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32_grad/convolution.h" + +namespace mindspore { +namespace lite { + +void ClassificationTrainAccuracyMonitor::Begin(const session::TrainLoopCallBackData &cb_data) { + if (cb_data.epoch_ == 0) accuracies_.clear(); +} + +void ClassificationTrainAccuracyMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) { + if (accuracies_.size() != cb_data.epoch_) { + MS_LOG(WARNING) << "Accuracies array does not match epoch number"; + } else { + accuracies_.push_back(std::make_pair(cb_data.epoch_, 0.0)); + } +} + +int ClassificationTrainAccuracyMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) { + if (cb_data.step_ > 0) accuracies_.at(cb_data.epoch_).second /= static_cast(cb_data.step_); + if ((cb_data.epoch_ + 1) % print_every_n_ == 0) { + std::cout << cb_data.epoch_ + 1 << ":\tTraining Accuracy is " << accuracies_.at(cb_data.epoch_).second << std::endl; + } + return mindspore::session::RET_CONTINUE; +} + +void ClassificationTrainAccuracyMonitor::StepEnd(const 
session::TrainLoopCallBackData &cb_data) { + auto inputs = cb_data.session_->GetInputs(); + auto outputs = cb_data.session_->GetPredictions(); + auto labels = reinterpret_cast(inputs.at(1)->MutableData()); + for (auto it = outputs.begin(); it != outputs.end(); ++it) { + if (it->second->ElementsNum() == inputs.at(1)->ElementsNum()) { + int batch_size = inputs.at(1)->shape().at(0); + int num_of_classes = inputs.at(1)->shape().at(1); + auto predictions = reinterpret_cast(it->second->MutableData()); + float accuracy = 0.0; + for (int b = 0; b < batch_size; b++) { + int label = 0; + int max_idx = 0; + float max_label_score = labels[num_of_classes * b]; + float max_score = predictions[num_of_classes * b]; + for (int c = 1; c < num_of_classes; c++) { + if (predictions[num_of_classes * b + c] > max_score) { + max_score = predictions[num_of_classes * b + c]; + max_idx = c; + } + if (labels[num_of_classes * b + c] > max_label_score) { + max_label_score = labels[num_of_classes * b + c]; + label = c; + } + } + if (label == max_idx) accuracy += 1.0; + } + accuracy /= static_cast(batch_size); + accuracies_.at(cb_data.epoch_).second = accuracy; + return; + } + } + + MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1"; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/train/loss_monitor.cc b/mindspore/lite/src/train/loss_monitor.cc new file mode 100644 index 0000000000..e8a4b94629 --- /dev/null +++ b/mindspore/lite/src/train/loss_monitor.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/train/loss_monitor.h" +#include +#include +#include +#include +#include +#include +#include +#include "include/errorcode.h" +#include "include/train_session.h" +#include "src/common/utils.h" +#include "src/tensor.h" + +namespace mindspore { +namespace lite { + +void LossMonitor::Begin(const session::TrainLoopCallBackData &cb_data) { + if (cb_data.epoch_ == 0) losses_.clear(); +} + +void LossMonitor::EpochBegin(const session::TrainLoopCallBackData &cb_data) { + if (losses_.size() != cb_data.epoch_) { + MS_LOG(WARNING) << "losses array does not match epoch number"; + } else { + losses_.push_back(std::make_pair(cb_data.epoch_, 0.0)); + } +} + +int LossMonitor::EpochEnd(const session::TrainLoopCallBackData &cb_data) { + if (cb_data.step_ > 0) losses_.at(cb_data.epoch_).second /= static_cast(cb_data.step_); + if ((cb_data.epoch_ + 1) % print_every_n_ == 0) { + std::cout << cb_data.epoch_ + 1 << ":\tLoss is " << losses_.at(cb_data.epoch_).second << std::endl; + } + return mindspore::session::RET_CONTINUE; +} + +void LossMonitor::StepEnd(const session::TrainLoopCallBackData &cb_data) { + auto outputs = cb_data.session_->GetOutputs(); + for (auto it = outputs.begin(); it != outputs.end(); ++it) { + if (it->second->ElementsNum() == 1) { + auto loss = reinterpret_cast(it->second->MutableData()); + losses_.at(cb_data.epoch_).second += loss[0]; + return; + } + } + MS_LOG(WARNING) << "Model does not have a loss output tensor of size 1"; +} + +} // namespace lite +} // namespace mindspore diff --git 
a/mindspore/lite/src/train/lr_scheduler.cc b/mindspore/lite/src/train/lr_scheduler.cc new file mode 100644 index 0000000000..9b6f9b30f2 --- /dev/null +++ b/mindspore/lite/src/train/lr_scheduler.cc @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/train/lr_scheduler.h" +#include +#include +#include +#include +#include +#include +#include +#include "include/errorcode.h" +#include "include/train_session.h" +#include "src/common/utils.h" +#include "src/tensor.h" + +namespace mindspore { +namespace lite { + +int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) { + if ((lr == nullptr) || (lr_cb_data == nullptr)) { + MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda"; + return DONT_UPDATE_LR; + } + float mult = *(static_cast(lr_cb_data)); + *lr = *lr * mult; + return UPDATE_LR; +} + +int StepLRLambda(float *lr, int epoch, void *lr_cb_data) { + if ((lr == nullptr) || (lr_cb_data == nullptr)) { + MS_LOG(ERROR) << "nullptr passed as input to StepLRLambda"; + return DONT_UPDATE_LR; + } + struct StepLRLambda *step_lr_data = (static_cast(lr_cb_data)); + if (((epoch + 1) % step_lr_data->step_size) == 0) { + *lr = *lr * step_lr_data->gamma; + return UPDATE_LR; + } + return DONT_UPDATE_LR; +} + +LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step) + : lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {} + 
+int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) { + if (((cb_data.epoch_ + 1) % step_) == 0) { + float lr = cb_data.session_->GetLearningRate(); + int update = lambda_func_(&lr, cb_data.epoch_, lr_data_); + if (update == UPDATE_LR) { + int ret = cb_data.session_->SetLearningRate(lr); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Error setting Learning rate in train session"; + return mindspore::session::RET_EXIT; + } + } + } + return mindspore::session::RET_CONTINUE; +} + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/train/optimizer_kernel.h b/mindspore/lite/src/train/optimizer_kernel.h new file mode 100644 index 0000000000..d73fc8c1ee --- /dev/null +++ b/mindspore/lite/src/train/optimizer_kernel.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ +#define MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ +#include +#include "src/lite_kernel.h" +namespace mindspore::kernel { + +class OptimizerKernel : public LiteKernel { + public: + OptimizerKernel() = default; + OptimizerKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} + ~OptimizerKernel() = default; + virtual int SetLearningRate(float lr) = 0; + virtual float GetLearningRate() = 0; +}; + +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_TRAIN_OPTIMIZER_KERNEL_H_ diff --git a/mindspore/lite/src/train/train_loop.cc b/mindspore/lite/src/train/train_loop.cc new file mode 100644 index 0000000000..727351e89f --- /dev/null +++ b/mindspore/lite/src/train/train_loop.cc @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/train/train_loop.h" +#include +#include +#include +#include +#include +#include +#include +#include "include/errorcode.h" +#include "include/train_session.h" +#include "src/common/utils.h" +#include "src/tensor.h" +#include "src/train/loss_kernel.h" +#include "src/train/optimizer_kernel.h" +#include "src/sub_graph_kernel.h" +#include "src/train/train_populate_parameter.h" +#include "src/runtime/runtime_api.h" +#include "src/executor.h" +#include "src/kernel_registry.h" +#include "src/runtime/kernel/arm/fp32_grad/convolution.h" + +namespace mindspore { +namespace lite { + +using session::RET_CONTINUE; +using session::RET_EXIT; +using session::RET_STOP_TRAINING; + +TrainLoop::~TrainLoop() { + if (train_session_ != nullptr) delete train_session_; +} + +int TrainLoop::Train(int epochs, std::vector cbs) { + train_session_->Train(); + session::TrainLoopCallBackData cb_data(true, epoch_, train_session_, this); + + for (auto cb : cbs) cb->Begin(cb_data); + + int steps_in_epoch = 1; // should be data_size/batch_size + for (int i = 0; i < epochs; i++) { + cb_data.epoch_ = epoch_++; + for (auto cb : cbs) cb->EpochBegin(cb_data); + + for (int s = 0; s < steps_in_epoch; s++) { + cb_data.step_ = s; + for (auto cb : cbs) cb->StepBegin(cb_data); + train_session_->RunGraph(before_cb_, after_cb_); + for (auto cb : cbs) cb->StepEnd(cb_data); + } + + int break_loop = false; + for (auto cb : cbs) { + int ret = cb->EpochEnd(cb_data); + if (ret != RET_CONTINUE) { + if (ret == RET_EXIT) { + MS_LOG(ERROR) << "Error in TrainLoop callback"; + return RET_ERROR; + } + if (ret == RET_STOP_TRAINING) { + break_loop = true; + } + } + } + if (break_loop) { + break; + } + } + + for (auto cb : cbs) cb->End(cb_data); + return RET_OK; +} + +} // namespace lite + +session::TrainLoop *session::TrainLoop::CreateTrainLoop(const std::string &model_filename, lite::Context *context, + int batch_size) { + auto train_session = session::TrainSession::CreateSession(model_filename, context); 
+ auto loop = new (std::nothrow) lite::TrainLoop(train_session); + + return loop; +} + +} // namespace mindspore diff --git a/mindspore/lite/src/train/train_loop.h b/mindspore/lite/src/train/train_loop.h new file mode 100644 index 0000000000..01559feadb --- /dev/null +++ b/mindspore/lite/src/train/train_loop.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ +#define MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ +#include +#include +#include +#include +#include "src/ops/primitive_c.h" +#include "include/train/train_loop.h" +#include "include/train_session.h" + +namespace mindspore { +namespace lite { + +class TrainLoop : virtual public session::TrainLoop { + public: + explicit TrainLoop(session::TrainSession *session) : train_session_(session) {} + + session::TrainSession *train_session() override { return train_session_; } + + int Reset() override { + epoch_ = 0; + return RET_OK; + } + + virtual ~TrainLoop(); + + int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) override { + before_cb_ = before; + after_cb_ = after; + return RET_OK; + } + + int Train(int epochs, std::vector cbs) override; + + protected: + session::TrainSession *train_session_ = nullptr; + unsigned int epoch_ = 0; + KernelCallBack before_cb_ = nullptr; + KernelCallBack after_cb_ = nullptr; + int batch_size; +}; +} // namespace lite +} // namespace 
mindspore +#endif // MINDSPORE_LITE_SRC_TRAIN_TRAIN_LOOP_H_ diff --git a/mindspore/lite/src/train/train_populate_parameter.cc b/mindspore/lite/src/train/train_populate_parameter.cc index 892a47dffd..333e76a8dd 100644 --- a/mindspore/lite/src/train/train_populate_parameter.cc +++ b/mindspore/lite/src/train/train_populate_parameter.cc @@ -15,6 +15,7 @@ */ #include "src/train/train_populate_parameter.h" +#include #include "src/ops/populate/populate_register.h" #include "src/ops/pooling_grad.h" #include "nnacl/pooling_parameter.h" @@ -517,12 +518,15 @@ OpParameter *PopulateArithmeticGradParameter(const mindspore::lite::PrimitiveC * arithmetic_param->broadcasting_ = ((lite::ArithmeticGrad *)primitive)->Broadcasting(); arithmetic_param->ndim_ = ((lite::ArithmeticGrad *)primitive)->NDims(); - auto tmp_shape = ((lite::ArithmeticGrad *)primitive)->x1Shape(); - memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); - tmp_shape = ((lite::ArithmeticGrad *)primitive)->x2Shape(); - memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); - tmp_shape = ((lite::ArithmeticGrad *)primitive)->dyShape(); - memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + auto shape = ((lite::ArithmeticGrad *)primitive)->x1Shape(); + auto source = static_cast(shape.data()); + std::copy(source, source + shape.size(), arithmetic_param->in_shape0_); + shape = ((lite::ArithmeticGrad *)primitive)->x2Shape(); + source = static_cast(shape.data()); + std::copy(source, source + shape.size(), arithmetic_param->in_shape1_); + shape = ((lite::ArithmeticGrad *)primitive)->dyShape(); + source = static_cast(shape.data()); + std::copy(source, source + shape.size(), arithmetic_param->out_shape_); return reinterpret_cast(arithmetic_param); } diff --git a/mindspore/lite/src/train/train_session.cc b/mindspore/lite/src/train/train_session.cc index 4e2941e9c6..0f7bcc1b7d 100644 --- 
a/mindspore/lite/src/train/train_session.cc +++ b/mindspore/lite/src/train/train_session.cc @@ -26,6 +26,7 @@ #include "src/common/utils.h" #include "src/tensor.h" #include "src/train/loss_kernel.h" +#include "src/train/optimizer_kernel.h" #include "src/sub_graph_kernel.h" #include "src/train/train_populate_parameter.h" #include "src/runtime/runtime_api.h" @@ -49,10 +50,8 @@ TrainSession::TrainSession() { kernel::PopulateTrainParameters(); } std::vector TrainSession::ReplaceOps() { const std::vector replace = { - {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_Conv2D}, - mindspore::kernel::CpuConvTrainFp32KernelCreator}, - {{mindspore::kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, mindspore::schema::PrimitiveType_DepthwiseConv2D}, - mindspore::kernel::CpuConvTrainFp32KernelCreator}}; + // currently no ops are Hijacked by TrainSession + }; mindspore::lite::KernelRegistry *reg = mindspore::lite::KernelRegistry::GetInstance(); std::vector results; for (auto v : replace) { @@ -98,7 +97,7 @@ int TrainSession::CompileTrainGraph(mindspore::lite::TrainModel *model) { RestoreOps(restore); CompileTrainKernels(); // Prepare a list of train kernels CompileInferenceKernels(); // Prepare a list of eval kernels - CompileOptimizedKernels(); // Prepare a list of kenels which are optimized (weight update step) + CompileOptimizedKernels(); // Prepare a list of kernels which are optimized (weight update step) CompileTrainOutputs(); // prepare outputs in train mode CompileEvalOutputs(); // prepare outputs in eval mode AllocWorkSpace(); @@ -302,6 +301,30 @@ void TrainSession::CompileOptimizedKernels() { } } +int TrainSession::SetLearningRate(float learning_rate) { + for (auto kernel : this->train_kernels_) { + if (IsOptimizer(kernel)) { + auto optimizer = reinterpret_cast(kernel); + auto ret = optimizer->SetLearningRate(learning_rate); + if (ret != RET_OK) { + MS_LOG(ERROR) << kernel->name() << " failed to set learning rate"; + return 
RET_ERROR; + } + } + } + return RET_OK; +} + +float TrainSession::GetLearningRate() { + for (auto kernel : this->train_kernels_) { + if (IsOptimizer(kernel)) { + auto optimizer = reinterpret_cast(kernel); + return optimizer->GetLearningRate(); + } + } + return 0.0; +} + bool TrainSession::IsLossKernel(const kernel::LiteKernel *kernel) const { return (kernel->Type() == schema::PrimitiveType_SoftmaxCrossEntropy || kernel->Type() == schema::PrimitiveType_SparseSoftmaxCrossEntropy || diff --git a/mindspore/lite/src/train/train_session.h b/mindspore/lite/src/train/train_session.h index fbeb7ed204..ea1f2e7ecd 100644 --- a/mindspore/lite/src/train/train_session.h +++ b/mindspore/lite/src/train/train_session.h @@ -42,7 +42,6 @@ namespace mindspore { namespace lite { - using CreatorOp = std::tuple; class TrainSession : virtual public session::TrainSession, virtual public lite::LiteSession { public: @@ -59,6 +58,8 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: int Train() override; int Eval() override; + int SetLearningRate(float learning_rate) override; + float GetLearningRate() override; void BindThread(bool if_bind) override { return lite::LiteSession::BindThread(if_bind); } std::vector GetInputs() const override { return lite::LiteSession::GetInputs(); } @@ -80,6 +81,10 @@ class TrainSession : virtual public session::TrainSession, virtual public lite:: return lite::RET_ERROR; } + std::unordered_map GetPredictions() const override { + return eval_output_tensor_map_; + } + protected: void AllocWorkSpace(); bool IsLossKernel(const kernel::LiteKernel *kernel) const; diff --git a/mindspore/lite/test/models_ms_train.cfg b/mindspore/lite/test/models_ms_train.cfg index 91ef4d47fd..b17284a7d7 100644 --- a/mindspore/lite/test/models_ms_train.cfg +++ b/mindspore/lite/test/models_ms_train.cfg @@ -1,11 +1,13 @@ mini_alexnet -#mobilenetv1 +# mobilenetv1 mobilenetv2 mobilenetv3 lenet effnet -effnet_tune +# effnet_tune # lenetv1 # resnet -# effnetv1 
+# googlenet +# densenet +# one_net #LAST \ No newline at end of file diff --git a/mindspore/lite/test/run_net_train.sh b/mindspore/lite/test/run_net_train.sh index a5435384fe..c161801a5e 100755 --- a/mindspore/lite/test/run_net_train.sh +++ b/mindspore/lite/test/run_net_train.sh @@ -83,7 +83,7 @@ function Run_x86() { --inDataFile=${train_io_path}/${model_name}_input1.bin,${train_io_path}/${model_name}_input2.bin \ --expectedDataFile=${train_io_path}/${model_name}_outputs.bin \ --exportFile=${ms_models_path}/${model_name}_train_exported.ms >> "${run_x86_log_file}" \ - --epochs=${epoch_num} + --epochs=${epoch_num} --numThreads=${threads} if [ $? = 0 ]; then run_result='x86: '${model_name}'_train pass'; echo ${run_result} >> ${run_benchmark_train_result_file} else @@ -178,7 +178,8 @@ function Run_arm() { --modelFile=${model_name}_train.ms \ --inDataFile=${tmp_dir}/${model_name}_input1.bin,${tmp_dir}/${model_name}_input2.bin \ --expectedDataFile=${tmp_dir}/${model_name}_outputs.bin \ - --exportFile=${tmp_dir}/${model_name}_train_exported.ms + --exportFile=${tmp_dir}/${model_name}_train_exported.ms \ + --numThreads=${threads} ENDM ) echo "${adb_cmd}" >> ${run_arm_log_file} @@ -221,8 +222,9 @@ echo ${basepath} # Example:run_benchmark_train.sh -r /home/emir/Work/TestingEnv/release -m /home/emir/Work/TestingEnv/train_models -i /home/emir/Work/TestingEnv/train_io -d "8KE5T19620002408" # For running on arm64, use -t to set platform tools path (for using adb commands) epoch_num=1 +threads=1 train_io_path="" -while getopts "r:m:d:i:e:vt:" opt; do +while getopts "r:m:d:i:e:vt:q:" opt; do case ${opt} in r) release_path=${OPTARG} @@ -249,9 +251,13 @@ while getopts "r:m:d:i:e:vt:" opt; do run_valgrind="valgrind --log-file=valgrind.log " echo "Run x86 with valgrind" ;; + q) + threads=${OPTARG} + echo "threads=${threads}" + ;; t) epoch_num=${OPTARG} - echo "train epoch num is ${OPTARG}" + echo "train epoch num is ${epoch_num}" ;; ?) 
echo "unknown para" diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 998ebfcf3e..1ce2081858 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -511,6 +511,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano auto valueNode = input_anode->cast(); auto paramTensor = std::make_unique(); auto value = valueNode->value(); +#ifdef SUPPORT_TRAIN + paramTensor->name = valueNode->fullname_with_scope(); +#endif if (value->isa()) { auto valueAbstract = valueNode->abstract(); auto abstractTensor = utils::cast(valueAbstract); @@ -527,7 +530,6 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano paramTensor->dims = dims; #ifdef SUPPORT_TRAIN if (paramTensor->dims.size() == 0) paramTensor->dims = {1}; - paramTensor->name = valueNode->fullname_with_scope(); #endif paramTensor->nodeType = schema::NodeType::NodeType_ValueNode; auto data = value->cast(); diff --git a/mindspore/lite/tools/benchmark_train/net_train.cc b/mindspore/lite/tools/benchmark_train/net_train.cc index ec5a14732b..2a70bb3718 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.cc +++ b/mindspore/lite/tools/benchmark_train/net_train.cc @@ -135,6 +135,7 @@ int NetTrain::ReadCalibData() { MS_LOG(INFO) << "Start reading calibData file"; std::string tensor_name; + while (!in_file.eof()) { getline(in_file, line); std::stringstream string_line1(line); @@ -189,7 +190,6 @@ int NetTrain::CompareOutput() { MS_ASSERT(tensor->MutableData() != nullptr); auto outputs = tensor->MutableData(); float bias = CompareData(node_or_tensor_name, tensor->shape(), reinterpret_cast(outputs)); - if (bias >= 0) { total_bias += bias; total_size++; @@ -228,7 +228,7 @@ int NetTrain::CompareOutput() { int NetTrain::MarkPerformance() { MS_LOG(INFO) << "Running train loops..."; std::cout << "Running train loops..." 
<< std::endl; - uint64_t time_min = 1000000; + uint64_t time_min = 0xFFFFFFFFFFFFFFFF; uint64_t time_max = 0; uint64_t time_avg = 0; diff --git a/mindspore/lite/tools/benchmark_train/net_train.h b/mindspore/lite/tools/benchmark_train/net_train.h index 4ecd4fd007..f724fb87db 100644 --- a/mindspore/lite/tools/benchmark_train/net_train.h +++ b/mindspore/lite/tools/benchmark_train/net_train.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ -#define MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ +#ifndef MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ +#define MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ #include #include @@ -59,6 +59,7 @@ class MS_API NetTrainFlags : public virtual FlagParser { AddFlag(&NetTrainFlags::warm_up_loop_count_, "warmUpLoopCount", "Run warm up loop", 0); AddFlag(&NetTrainFlags::time_profiling_, "timeProfiling", "Run time profiling", false); AddFlag(&NetTrainFlags::epochs_, "epochs", "Number of training epochs to run", 1); + AddFlag(&NetTrainFlags::num_threads_, "numThreads", "Run threads number", 1); // MarkAccuracy AddFlag(&NetTrainFlags::data_file_, "expectedDataFile", "Expected results data file path", ""); AddFlag(&NetTrainFlags::export_file_, "exportFile", "MS File to export trained model into", ""); @@ -239,4 +240,4 @@ class MS_API NetTrain { int MS_API RunNetTrain(int argc, const char **argv); } // namespace mindspore::lite -#endif // MINDSPORE_LITE_TOOLS_NET_TRAIN_NET_TRAIN_H_ +#endif // MINDSPORE_LITE_TOOLS_BENCHMARK_TRAIN_NET_TRAIN_H_ diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 4d433c4e49..b10055cbb4 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -136,10 +136,10 @@ static const std::vector int8OpList = {schema::PrimitiveT static const std::vector needInsertOpList = { #ifdef SUPPORT_TRAIN - schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, 
schema::PrimitiveType_Concat, - schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, - schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, - schema::PrimitiveType_Add + schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, + schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Split, + schema::PrimitiveType_Crop, schema::PrimitiveType_Mul, schema::PrimitiveType_Add, + schema::PrimitiveType_ActivationGrad #else schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add,