tod add train loop

pull/11135/head
yoni 4 years ago
parent 8008843562
commit 33d7741904

@ -395,6 +395,7 @@ if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" || "X$ENABLE_GPU" = "
git submodule update --init --recursive akg
fi
build_exit()
{
echo "$@" >&2
@ -596,33 +597,40 @@ build_lite()
build_lite_java_arm64() {
# build mindspore-lite arm64
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch64
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch64
fi
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then
build_lite "arm64" "off"
fi
# copy arm64 so
cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64
tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch64.tar.gz
rm -rf ${JTARBALL}
tar -zxvf ${JTARBALL}.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/arm64-v8a/
mkdir -p ${JAVA_PATH}/java/app/libs/arm64-v8a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch64/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
echo mindspore-lite-${VERSION_STR}-inference-android-aarch64
[ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch64
cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/arm64-v8a/
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_lite_java_arm32() {
# build mindspore-lite arm32
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-inference-android-aarch32
if [[ "X$SUPPORT_TRAIN" = "Xon" ]]; then
JTARBALL=mindspore-lite-${VERSION_STR}-train-android-aarch32
fi
if [[ "X$INC_BUILD" = "Xoff" ]] || [[ ! -f "${BASEPATH}/output/${JTARBALL}.tar.gz" ]]; then
build_lite "arm32" "off"
fi
# copy arm32 so
cd ${BASEPATH}/output/
rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32
tar -zxvf mindspore-lite-${VERSION_STR}-inference-android-aarch32.tar.gz
rm -rf ${JTARBALL}
tar -zxvf ${JTARBALL}.tar.gz
[ -n "${JAVA_PATH}" ] && rm -rf ${JAVA_PATH}/java/app/libs/armeabi-v7a/
mkdir -p ${JAVA_PATH}/java/app/libs/armeabi-v7a/
cp ${BASEPATH}/output/mindspore-lite-${VERSION_STR}-inference-android-aarch32/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${VERSION_STR}" ] && rm -rf mindspore-lite-${VERSION_STR}-inference-android-aarch32
cp ${BASEPATH}/output/${JTARBALL}/lib/libmindspore-lite.so ${JAVA_PATH}/java/app/libs/armeabi-v7a/
[ -n "${VERSION_STR}" ] && rm -rf ${JTARBALL}
}
build_jni_arm64() {
@ -635,7 +643,7 @@ build_jni_arm64() {
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="arm64-v8a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
-DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native"
-DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM64=on "${JAVA_PATH}/java/app/src/main/native"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni arm64 failed----------------"
@ -655,7 +663,7 @@ build_jni_arm32() {
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="aarch64-linux-android-clang" \
-DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
-DANDROID_STL="c++_static" -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
-DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native"
-DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DPLATFORM_ARM32=on "${JAVA_PATH}/java/app/src/main/native"
make -j$THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "---------------- mindspore lite: build jni arm32 failed----------------"

@ -3,7 +3,7 @@ APP:=bin/net_runner
MSLIB:=mindspore-lite
MSDIR:=$(realpath package-$(TARGET)/lib)
SRC:=src/net_runner.cc src/dataset.cc
SRC:=src/net_runner.cc src/dataset.cc src/data_callbacks.cc
OBJ:=$(SRC:.cc=.o)
CFLAGS := -Ofast -std=c++17 \

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"
using GraphPoint = std::pair<int, float>;
class AccuracyMonitor : public mindspore::session::TrainLoopCallBack {
public:
explicit AccuracyMonitor(DataSet *dataset, int check_every_n, int max_steps = -1)
: ds_(dataset), check_every_n_(check_every_n), max_steps_(max_steps) {}
int EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:
DataSet *ds_;
std::vector<GraphPoint> accuracies_;
int check_every_n_;
int max_steps_;
};
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_ACCURACY_MONITOR_H_

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include <getopt.h>
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include "src/net_runner.h"
#include "include/context.h"
#include "src/utils.h"
#include "src/data_loader.h"
#include "src/accuracy_monitor.h"
static unsigned int seed = time(NULL);
std::vector<int> FillInputDataUtil(const mindspore::session::TrainLoopCallBackData &cb_data,
const std::vector<DataLabelTuple> &dataset, bool serially) {
static unsigned int idx = 1;
int total_size = dataset.size();
std::vector<int> labels_vec;
auto inputs = cb_data.session_->GetInputs();
char *input_data = reinterpret_cast<char *>(inputs.at(0)->MutableData());
auto labels = reinterpret_cast<float *>(inputs.at(1)->MutableData());
int batch_size = inputs.at(0)->shape()[0];
int num_of_classes = inputs.at(1)->shape()[1];
int data_size = inputs.at(0)->Size() / batch_size;
MS_ASSERT(total_size > 0);
MS_ASSERT(input_data != nullptr);
std::fill(labels, labels + inputs.at(1)->ElementsNum(), 0.f);
for (int i = 0; i < batch_size; i++) {
if (serially) {
idx = ++idx % total_size;
} else {
idx = rand_r(&seed) % total_size;
}
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::copy(data, data + data_size, input_data + i * data_size);
labels[i * num_of_classes + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}
return labels_vec;
}
void DataLoader::StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) {
FillInputDataUtil(cb_data, ds_->train_data(), false);
}
int AccuracyMonitor::EpochEnd(const mindspore::session::TrainLoopCallBackData &cb_data) {
if ((cb_data.epoch_ + 1) % check_every_n_ != 0) return mindspore::session::RET_CONTINUE;
float accuracy = 0.0;
auto inputs = cb_data.session_->GetInputs();
int batch_size = inputs.at(0)->shape()[0];
int num_of_classes = ds_->num_of_classes();
int tests = ds_->test_data().size() / batch_size;
if (max_steps_ != -1 && tests > max_steps_) tests = max_steps_;
cb_data.session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputDataUtil(cb_data, ds_->test_data(), false);
cb_data.session_->RunGraph();
auto outputs = cb_data.session_->GetPredictions();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == batch_size * num_of_classes) {
auto scores = reinterpret_cast<float *>(it->second->MutableData());
for (int b = 0; b < batch_size; b++) {
int max_idx = 0;
float max_score = scores[num_of_classes * b];
for (int c = 1; c < num_of_classes; c++) {
if (scores[num_of_classes * b + c] > max_score) {
max_score = scores[num_of_classes * b + c];
max_idx = c;
}
}
if (labels[b] == max_idx) accuracy += 1.0;
}
break;
}
}
}
accuracy /= static_cast<float>(batch_size * tests);
accuracies_.push_back(std::make_pair(cb_data.epoch_, accuracy));
std::cout << cb_data.epoch_ + 1 << ":\tAccuracy is " << accuracy << std::endl;
cb_data.session_->Train();
return mindspore::session::RET_CONTINUE;
}

@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#define MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
#include "src/dataset.h"
class DataLoader : public mindspore::session::TrainLoopCallBack {
public:
explicit DataLoader(DataSet *dataset) : ds_(dataset) {}
void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override;
private:
DataSet *ds_;
};
#endif // MINDSPORE_LITE_EXAMPLES_TRAIN_LENET_SRC_DATA_LOADER_H_

@ -20,10 +20,21 @@
#include <cstring>
#include <iostream>
#include <fstream>
#include <utility>
#include "include/context.h"
#include "include/train/loss_monitor.h"
#include "include/train/ckpt_saver.h"
#include "include/train/lr_scheduler.h"
#include "include/train/classification_train_accuracy_monitor.h"
#include "src/utils.h"
#include "src/data_loader.h"
#include "src/accuracy_monitor.h"
using mindspore::session::TrainLoopCallBack;
using mindspore::session::TrainLoopCallBackData;
static unsigned int seed = time(NULL);
unsigned int NetRunner::seed_ = time(NULL);
// Definition of callback function after forwarding operator.
bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
@ -54,15 +65,18 @@ bool after_callback(const std::vector<mindspore::tensor::MSTensor *> &after_inpu
}
NetRunner::~NetRunner() {
if (session_ != nullptr) delete session_;
if (loop_ != nullptr) delete loop_;
}
void NetRunner::InitAndFigureInputs() {
mindspore::lite::Context context;
context.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = mindspore::lite::NO_BIND;
context.thread_num_ = 1;
context.device_list_[0].device_info_.cpu_device_info_.enable_float16_ = false;
context.device_list_[0].device_type_ = mindspore::lite::DT_CPU;
context.thread_num_ = 2;
session_ = mindspore::session::TrainSession::CreateSession(ms_file_, &context);
loop_ = mindspore::session::TrainLoop::CreateTrainLoop(ms_file_, &context);
session_ = loop_->train_session();
MS_ASSERT(nullptr != session_);
auto inputs = session_->GetInputs();
@ -76,71 +90,10 @@ void NetRunner::InitAndFigureInputs() {
}
}
mindspore::tensor::MSTensor *NetRunner::SearchOutputsForSize(size_t size) const {
auto outputs = session_->GetOutputs();
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
if (it->second->ElementsNum() == size) return it->second;
}
std::cout << "Model does not have an output tensor with size " << size << std::endl;
return nullptr;
}
std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dataset, bool serially) const {
std::vector<int> labels_vec;
static unsigned int idx = 1;
int total_size = dataset.size();
auto inputs = session_->GetInputs();
char *input_data = reinterpret_cast<char *>(inputs.at(data_index_)->MutableData());
auto labels = reinterpret_cast<float *>(inputs.at(label_index_)->MutableData());
MS_ASSERT(total_size > 0);
MS_ASSERT(input_data != nullptr);
std::fill(labels, labels + inputs.at(label_index_)->ElementsNum(), 0.f);
for (int i = 0; i < batch_size_; i++) {
if (serially) {
idx = ++idx % total_size;
} else {
idx = rand_r(&seed_) % total_size;
}
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::memcpy(input_data + i * data_size_, data, data_size_);
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}
return labels_vec;
}
float NetRunner::CalculateAccuracy(int max_tests) const {
float accuracy = 0.0;
const std::vector<DataLabelTuple> test_set = ds_.test_data();
int tests = test_set.size() / batch_size_;
if (max_tests != -1 && tests < max_tests) tests = max_tests;
session_->Eval();
for (int i = 0; i < tests; i++) {
auto labels = FillInputData(test_set, (max_tests == -1));
session_->RunGraph();
auto outputsv = SearchOutputsForSize(batch_size_ * num_of_classes_);
MS_ASSERT(outputsv != nullptr);
auto scores = reinterpret_cast<float *>(outputsv->MutableData());
for (int b = 0; b < batch_size_; b++) {
int max_idx = 0;
float max_score = scores[num_of_classes_ * b];
for (int c = 0; c < num_of_classes_; c++) {
if (scores[num_of_classes_ * b + c] > max_score) {
max_score = scores[num_of_classes_ * b + c];
max_idx = c;
}
}
if (labels[b] == max_idx) accuracy += 1.0;
}
}
session_->Train();
accuracy /= static_cast<float>(batch_size_ * tests);
return accuracy;
float NetRunner::CalculateAccuracy(int max_tests) {
AccuracyMonitor test_am(&ds_, 1, max_tests);
test_am.EpochEnd(TrainLoopCallBackData(true, 0, session_, loop_));
return 0.0;
}
int NetRunner::InitDB() {
@ -155,35 +108,17 @@ int NetRunner::InitDB() {
return ret;
}
float NetRunner::GetLoss() const {
auto outputsv = SearchOutputsForSize(1); // Search for Loss which is a single value tensor
MS_ASSERT(outputsv != nullptr);
auto loss = reinterpret_cast<float *>(outputsv->MutableData());
return loss[0];
}
int NetRunner::TrainLoop() {
session_->Train();
float min_loss = 1000.;
float max_acc = 0.;
for (int i = 0; i < cycles_; i++) {
FillInputData(ds_.train_data());
session_->RunGraph(nullptr, verbose_ ? after_callback : nullptr);
float loss = GetLoss();
if (min_loss > loss) min_loss = loss;
if (save_checkpoint_ != 0 && (i + 1) % save_checkpoint_ == 0) {
auto cpkt_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(i + 1) + ".ms";
session_->SaveToFile(cpkt_fn);
}
struct mindspore::lite::StepLRLambda step_lr_lambda(100, 0.9);
mindspore::lite::LRScheduler step_lr_sched(mindspore::lite::StepLRLambda, static_cast<void *>(&step_lr_lambda), 100);
if ((i + 1) % 100 == 0) {
float acc = CalculateAccuracy(10);
if (max_acc < acc) max_acc = acc;
std::cout << i + 1 << ":\tLoss is " << std::setw(7) << loss << " [min=" << min_loss << "] "
<< " max_acc=" << max_acc << std::endl;
}
}
mindspore::lite::LossMonitor lm(100);
// mindspore::lite::ClassificationTrainAccuracyMonitor am(10);
mindspore::lite::CkptSaver cs(1000, std::string("lenet"));
AccuracyMonitor test_am(&ds_, 500, 10);
DataLoader dl(&ds_);
loop_->Train(cycles_, std::vector<TrainLoopCallBack *>{&dl, &lm, &test_am, &cs, &step_lr_sched});
return 0;
}
@ -194,8 +129,7 @@ int NetRunner::Main() {
TrainLoop();
float acc = CalculateAccuracy();
std::cout << "accuracy = " << acc << std::endl;
CalculateAccuracy();
if (cycles_ > 0) {
auto trained_fn = ms_file_.substr(0, ms_file_.find_last_of('.')) + "_trained_" + std::to_string(cycles_) + ".ms";

@ -23,6 +23,7 @@
#include <vector>
#include <string>
#include "include/train_session.h"
#include "include/train/train_loop.h"
#include "include/ms_tensor.h"
#include "src/dataset.h"
@ -38,12 +39,13 @@ class NetRunner {
int InitDB();
int TrainLoop();
std::vector<int> FillInputData(const std::vector<DataLabelTuple> &dataset, bool is_train_set = false) const;
float CalculateAccuracy(int max_tests = -1) const;
float CalculateAccuracy(int max_tests = -1);
float GetLoss() const;
mindspore::tensor::MSTensor *SearchOutputsForSize(size_t size) const;
DataSet ds_;
mindspore::session::TrainSession *session_ = nullptr;
mindspore::session::TrainLoop *loop_ = nullptr;
std::string ms_file_ = "";
std::string data_dir_ = "";

@ -17,9 +17,10 @@
#include "src/net_runner.h"
#include <math.h>
#include <getopt.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <fstream>
#include <iostream>
#include "include/context.h"
#include "src/utils.h"
@ -113,7 +114,7 @@ std::vector<int> NetRunner::FillInputData(const std::vector<DataLabelTuple> &dat
int label = 0;
char *data = nullptr;
std::tie(data, label) = dataset[idx];
std::memcpy(input_data + i * data_size_, data, data_size_);
std::copy(data, data + data_size, input_data + i * data_size);
labels[i * num_of_classes_ + label] = 1.0; // Model expects labels in onehot representation
labels_vec.push_back(label);
}

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_
#include <stdio.h>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
#include "include/train/train_loop.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class CkptSaver : public session::TrainLoopCallBack {
public:
CkptSaver(int save_every_n, const std::string &filename_prefix)
: save_every_n_(save_every_n), filename_prefix_(filename_prefix) {}
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override {
if ((cb_data.epoch_ + 1) % save_every_n_ == 0) {
auto cpkt_fn = filename_prefix_ + "_trained_" + std::to_string(cb_data.epoch_ + 1) + ".ms";
remove(cpkt_fn.c_str());
cb_data.session_->SaveToFile(cpkt_fn);
}
return session::RET_CONTINUE;
}
private:
int save_every_n_;
std::string filename_prefix_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CKPT_SAVER_H_

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class ClassificationTrainAccuracyMonitor : public session::TrainLoopCallBack {
public:
explicit ClassificationTrainAccuracyMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}
virtual ~ClassificationTrainAccuracyMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetAccuracyPoints() const { return accuracies_; }
private:
std::vector<GraphPoint> accuracies_;
int print_every_n_ = 0;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_CLASSIFICATION_TRAIN_ACCURACY_MONITOR_H_

@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_
#include <vector>
#include <string>
#include <utility>
#include <climits>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
using GraphPoint = std::pair<int, float>;
namespace mindspore {
namespace lite {
class LossMonitor : public session::TrainLoopCallBack {
public:
explicit LossMonitor(int print_every_n = INT_MAX) : print_every_n_(print_every_n) {}
virtual ~LossMonitor() = default;
void Begin(const session::TrainLoopCallBackData &cb_data) override;
void EpochBegin(const session::TrainLoopCallBackData &cb_data) override;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
void StepEnd(const session::TrainLoopCallBackData &cb_data) override;
const std::vector<GraphPoint> &GetLossPoints() const { return losses_; }
private:
std::vector<GraphPoint> losses_;
int print_every_n_;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LOSS_MONITOR_H_

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
namespace mindspore {
namespace lite {
constexpr int DONT_UPDATE_LR = 0;
constexpr int UPDATE_LR = 1;
using LR_Lambda = std::function<int(float *lr, int epoch, void *cb_data)>;
/// \brief Multiply the LR by a factor of gamma every epoch
int MultiplicativeLRLambda(float *lr, int epoch, void *multiplication);
/// \brief Multiply the LR by a factor of gamma every step_size
int StepLRLambda(float *lr, int epoch, void *step_size);
struct StepLRLambda {
StepLRLambda(int step, float g) : step_size(step), gamma(g) {}
int step_size; // period of LR decay
float gamma; // LR decay factor
};
class LRScheduler : public session::TrainLoopCallBack {
public:
explicit LRScheduler(LR_Lambda lambda_func, void *lr_cb_data = nullptr, int step_ = 1);
virtual ~LRScheduler() = default;
int EpochEnd(const session::TrainLoopCallBackData &cb_data) override;
private:
LR_Lambda lambda_func_;
void *lr_data_ = nullptr;
int step_ = 1;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_LR_SCHEDULER_H_

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "include/train/train_loop_callback.h"
#include "include/train_session.h"
namespace mindspore {
namespace session {
class TrainLoop {
public:
/// \brief Static method to create a TrainLoop object
///
/// \param[in] filename Filename to read flatbuffer from
/// \param[in] context Defines the context of the session to be created
///
/// \return Pointer of MindSpore Lite TrainLoop
static TrainLoop *CreateTrainLoop(const std::string &model_filename, lite::Context *context, int batch_size = -1);
/// \brief Class destructor
virtual ~TrainLoop() = default;
/// \brief Resets the epoch counter
///
/// \return 0 on success or -1 in case of error
virtual int Reset() = 0; // resets the epoch counter to 0.
/// \brief Accessor to the TrainSession
///
/// \return pointer of the train_session
virtual session::TrainSession *train_session() = 0;
/// \brief Accessor to the Session KernelCallbacks
///
/// \param[in] before Define a call_back_function to be called before running each node.
/// \param[in] after Define a call_back_function called after running each node.
///
/// \return 0 on success or -1 in case of error
virtual int SetKernelCallBack(const KernelCallBack &before, const KernelCallBack &after) = 0;
/// \brief Performs the training Loop
///
/// \param[in] epoch The number of epochs to run
/// \param[in] cbs A vector of TrainLoopCallBack objects
///
/// \return 0 on success or -1 in case of error
virtual int Train(int epochs, std::vector<TrainLoopCallBack *> cbs) = 0;
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_H_

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
namespace mindspore {
namespace session {
class TrainSession;
class TrainLoop;
struct TrainLoopCallBackData {
TrainLoopCallBackData(bool train_mode, int epoch, TrainSession *session, TrainLoop *loop)
: train_mode_(train_mode), epoch_(epoch), session_(session), loop_(loop) {}
bool train_mode_; /**< training mode of TrainSession object */
unsigned int epoch_; /**< the current training epoch (starts at 0) */
unsigned int step_ = 0; /**< the current step within the epoch */
TrainSession *session_; /**< pointer to the TrainSession */
TrainLoop *loop_;
};
constexpr int RET_CONTINUE = 0;
constexpr int RET_STOP_TRAINING = 1;
constexpr int RET_EXIT = 2;
class TrainLoopCallBack {
public:
virtual ~TrainLoopCallBack() = default;
virtual void Begin(const TrainLoopCallBackData &cb_data) {}
virtual void End(const TrainLoopCallBackData &cb_data) {}
virtual void EpochBegin(const TrainLoopCallBackData &cb_data) {}
virtual int EpochEnd(const TrainLoopCallBackData &cb_data) { return RET_CONTINUE; }
virtual void StepBegin(const TrainLoopCallBackData &cb_data) {}
virtual void StepEnd(const TrainLoopCallBackData &cb_data) {}
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_TRAIN_LOOP_CALLBACK_H_

@ -83,6 +83,23 @@ class TrainSession : public session::LiteSession {
/// \return boolean indication if model is in eval mode
bool IsEval() { return train_mode_ == false; }
/// \brief Sets the Learning Rate of the training
///
/// \param[in] learning_rate to set
///
/// \return STATUS as an error code of the set operation, STATUS is defined in errorcode.h
virtual int SetLearningRate(float learning_rate) = 0;
/// \brief Gets the Learning Rate of the training
///
/// \return learning rate. 0.0 if no optimizer was found
virtual float GetLearningRate() = 0;
/// \brief Get output MindSpore Lite MSTensors of Training model prediction
///
/// \return The map of output tensor name and MindSpore Lite MSTensor.
virtual std::unordered_map<std::string, mindspore::tensor::MSTensor *> GetPredictions() const = 0;
protected:
bool train_mode_ = false;
};

@ -0,0 +1,178 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.lite;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.mindspore.lite.config.MSConfig;
public class TrainSession {
static {
System.loadLibrary("mindspore-lite-jni");
}
private long sessionPtr;
public TrainSession() {
this.sessionPtr = 0;
}
public boolean init(String modelFilename, MSConfig config) {
this.sessionPtr = createSession(modelFilename, config.getMSConfigPtr());
return this.sessionPtr != 0;
}
public long getSessionPtr() {
return sessionPtr;
}
public void bindThread(boolean if_bind) {
this.bindThread(this.sessionPtr, if_bind);
}
public boolean runGraph() {
return this.runGraph(this.sessionPtr);
}
public List<MSTensor> getInputs() {
List<Long> ret = this.getInputs(this.sessionPtr);
ArrayList<MSTensor> tensors = new ArrayList<MSTensor>();
for (Long ms_tensor_addr : ret) {
MSTensor msTensor = new MSTensor(ms_tensor_addr);
tensors.add(msTensor);
}
return tensors;
}
public MSTensor getInputsByTensorName(String tensorName) {
Long tensor_addr = this.getInputsByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
MSTensor msTensor = new MSTensor(tensor_addr);
return msTensor;
}
public List<MSTensor> getOutputsByNodeName(String nodeName) {
List<Long> ret = this.getOutputsByNodeName(this.sessionPtr, nodeName);
ArrayList<MSTensor> tensors = new ArrayList<>();
for (Long msTensorAddr : ret) {
MSTensor msTensor = new MSTensor(msTensorAddr);
tensors.add(msTensor);
}
return tensors;
}
public Map<String, MSTensor> getOutputMapByTensor() {
Map<String, Long> ret = this.getOutputMapByTensor(this.sessionPtr);
Map<String, MSTensor> tensorMap = new HashMap<>();
Set<Map.Entry<String, Long>> entrySet = ret.entrySet();
for (Map.Entry<String, Long> entry : entrySet) {
String name = entry.getKey();
Long msTensorAddr = entry.getValue();
tensorMap.put(name, new MSTensor(msTensorAddr));
}
return tensorMap;
}
public List<String> getOutputTensorNames() {
return getOutputTensorNames(this.sessionPtr);
}
public MSTensor getOutputByTensorName(String tensorName) {
Long tensor_addr = getOutputByTensorName(this.sessionPtr, tensorName);
if(tensor_addr == null){
return null;
}
return new MSTensor(tensor_addr);
}
public void free() {
this.free(this.sessionPtr);
this.sessionPtr = 0;
}
public boolean resize(List<MSTensor> inputs, int[][] dims) {
long[] inputs_array = new long[inputs.size()];
for (int i = 0; i < inputs.size(); i++) {
inputs_array[i] = inputs.get(i).getMSTensorPtr();
}
return this.resize(this.sessionPtr, inputs_array, dims);
}
public boolean saveToFile(String modelFilename) {
return this.saveToFile(this.sessionPtr, modelFilename);
}
public boolean train() {
return this.train(this.sessionPtr);
}
public boolean eval() {
return this.eval(this.sessionPtr);
}
public boolean isTrain() {
return this.isTrain(this.sessionPtr);
}
public boolean isEval() {
return this.isEval(this.sessionPtr);
}
public boolean setLearningRate(float learning_rate) {
return this.setLearningRate(this.sessionPtr, learning_rate);
}
private native long createSession(String modelFilename, long msConfigPtr);
private native void bindThread(long sessionPtr, boolean if_bind);
private native boolean runGraph(long sessionPtr);
private native List<Long> getInputs(long sessionPtr);
private native long getInputsByTensorName(long sessionPtr, String tensorName);
private native List<Long> getOutputsByNodeName(long sessionPtr, String nodeName);
private native Map<String, Long> getOutputMapByTensor(long sessionPtr);
private native List<String> getOutputTensorNames(long sessionPtr);
private native long getOutputByTensorName(long sessionPtr, String tensorName);
private native void free(long sessionPtr);
private native boolean resize(long sessionPtr, long[] inputs, int[][] dims);
private native boolean saveToFile(long sessionPtr, String modelFilename);
private native boolean train(long sessionPtr);
private native boolean eval(long sessionPtr);
private native boolean isTrain(long sessionPtr);
private native boolean isEval(long sessionPtr);
private native boolean setLearningRate(long sessionPtr, float learning_rate);
}

@ -7,8 +7,10 @@ set(PLATFORM_ARM "on")
set(MS_VERSION_MAJOR ${MS_VERSION_MAJOR})
set(MS_VERSION_MINOR ${MS_VERSION_MINOR})
set(MS_VERSION_REVISION ${MS_VERSION_REVISION})
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} -DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMS_VERSION_MAJOR=${MS_VERSION_MAJOR} -DMS_VERSION_MINOR=${MS_VERSION_MINOR} \
-DMS_VERSION_REVISION=${MS_VERSION_REVISION}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
#set for cross-compiling toolchain
@ -16,16 +18,16 @@ set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
if (ENABLE_VERBOSE)
if(ENABLE_VERBOSE)
set(CMAKE_VERBOSE_MAKEFILE on)
endif ()
endif()
if (PLATFORM_ARM32)
if(PLATFORM_ARM32)
add_compile_definitions(ENABLE_ARM32)
endif ()
if (PLATFORM_ARM64)
endif()
if(PLATFORM_ARM64)
add_compile_definitions(ENABLE_ARM64)
endif ()
endif()
set(TOP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../../../../../..)
set(LITE_DIR ${TOP_DIR}/mindspore/lite)
@ -40,15 +42,25 @@ include_directories(${LITE_DIR}/build) ## flatbuffers
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI}/)
add_library(mindspore-lite-jni SHARED
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
)
set(JNI_SRC
${CMAKE_CURRENT_SOURCE_DIR}/common/jni_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/model.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/version.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_config.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/ms_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/lite_session.cpp
)
if(SUPPORT_TRAIN)
set(JNI_SRC
${JNI_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/train_session.cpp
)
endif()
add_library(mindspore-lite-jni SHARED ${JNI_SRC})
find_library(log-lib log)
target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib})
target_link_libraries(mindspore-lite-jni mindspore-lite ${log-lib})

@ -800,8 +800,8 @@ int ElementSubRelu6(const float *in0, const float *in1, float *out, int size) {
int BroadcastDiv(const float *in0, const float *in1, float *tile_in0, float *tile_in1, float *out, int size,
ArithmeticParameter *param) {
TileDimensionsFp32(in0, in1, tile_in0, tile_in0, param);
return ElementDiv(tile_in0, tile_in0, out, size);
TileDimensionsFp32(in0, in1, tile_in0, tile_in1, param);
return ElementDiv(tile_in0, tile_in1, out, size);
}
int ElementDiv(const float *in0, const float *in1, float *out, int size) {

@ -21,32 +21,48 @@
#include "nnacl/errorcode.h"
inline int ReluGrad(float *src0, float *src1, size_t length, float *dst) {
for (size_t i = 0; i < length; ++i) {
if (src1[i] > 0) {
dst[i] = src0[i];
} else {
dst[i] = 0;
}
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
for (; i < length - 4; i += 4) {
float32x4_t src1_4 = vld1q_f32(src1 + i);
float32x4_t src0_4 = vld1q_f32(src0 + i);
uint32x4_t mask_4 = vcgtq_f32(src1_4, zero_4);
float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
vst1q_f32(dst + i, dst_4);
}
#endif
for (; i < length; ++i) {
dst[i] = (src1[i] > 0.0f) ? src0[i] : 0.0f;
}
return NNACL_OK;
}
int Relu6Grad(float *src0, float *src1, size_t length, float *dst) {
for (size_t i = 0; i < length; ++i) {
if (src1[i] > 0.0f && src1[i] <= 6.0f) {
dst[i] = src0[i];
} else {
dst[i] = 0.0f;
}
int i = 0;
#ifdef ENABLE_ARM
float32x4_t zero_4 = vdupq_n_f32(0.0f);
float32x4_t six_4 = vdupq_n_f32(6.0f);
for (; i < length - 4; i += 4) {
float32x4_t src1_4 = vld1q_f32(src1 + i);
float32x4_t src0_4 = vld1q_f32(src0 + i);
float32x4_t max_4 = vmaxq_f32(src1_4, zero_4);
float32x4_t min_max_4 = vminq_f32(max_4, six_4);
uint32x4_t mask_4 = vceqq_f32(min_max_4, src1_4);
float32x4_t dst_4 = vbslq_f32(mask_4, src0_4, zero_4);
vst1q_f32(dst + i, dst_4);
}
#endif
for (; i < length; ++i) {
dst[i] = (src1[i] > 0.0f && src1[i] <= 6.0f) ? src0[i] : 0.0f;
}
return NNACL_OK;
}
int LReluGrad(float *src0, float *src1, size_t length, float *dst, float alpha) {
for (size_t i = 0; i < length; ++i) {
dst[i] = src1[i] > 0.0f ? 1.0f : alpha;
dst[i] = src1[i] > 0.0f ? src0[i] : alpha * src0[i];
}
ElementMul(src0, dst, dst, length);
return NNACL_OK;
}

@ -17,55 +17,36 @@
#include <string.h>
#include "nnacl/fp32_grad/batch_norm.h"
void sumSpatialBatch(const float *in, size_t size, int ch, float *out) {
memset(out, 0, ch * sizeof(float));
for (size_t i = 0; i < size; i++) {
const float *ptr = in + (i * ch);
for (size_t c = 0; c < ch; c++) {
out[c] += ptr[c];
}
void var2Invar(float *save_var, int size, float eps) {
for (int i = 0; i < size; i++) {
save_var[i] = 1.0f / sqrt(save_var[i] + eps);
}
}
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) {
const float N = (size);
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dx_hat = dout[ix] * scale[f];
dxhat_sum[f] += dx_hat;
dxhathat_sum[f] += dx_hat * x_hat;
void backwardAll(const float *restrict in, const float *restrict yt, const float *restrict mean,
const float *restrict invar, const float *restrict scale, int size, int ch, float *restrict dxhat_sum,
float *restrict dxhathat_sum, float *restrict dbias, float *restrict dscale, float *restrict dx) {
float N = (float)size;
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
int ix = i * ch + c;
dbias[c] += yt[ix];
// dscale
float x_hat = (in[ix] - mean[c]) * invar[c];
dscale[c] += (yt[ix] * x_hat);
// dx_1
float dx_hat = yt[ix] * scale[c];
dxhat_sum[c] += dx_hat;
dxhathat_sum[c] += dx_hat * x_hat;
}
}
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dx_hat = dout[ix] * scale[f];
out[ix] = 1.0f / N * (invar[f]) * (N * dx_hat - dxhat_sum[f] - x_hat * dxhathat_sum[f]);
}
}
}
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates) {
size_t i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += (delta[index] * x_norm);
}
for (int i = 0; i < size; i++) {
for (int c = 0; c < ch; c++) {
// dx_2
int ix = i * ch + c;
float x_hat = (in[ix] - mean[c]) * invar[c];
float dx_hat = yt[ix] * scale[c];
dx[ix] = 1.0f / N * (invar[c]) * (N * dx_hat - dxhat_sum[c] - x_hat * dxhathat_sum[c]);
}
}
}
void var2Invar(float *save_var, size_t size, float eps) {
for (size_t i = 0; i < size; i++) {
save_var[i] = 1.0f / sqrt(save_var[i] + eps);
}
}

@ -29,13 +29,9 @@ typedef struct BNGradParameter {
extern "C" {
#endif
void sumSpatialBatch(const float *in, size_t size, int ch, float *out);
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *xhat_sum, float *dxhat_sum, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates);
void var2Invar(float *save_var, size_t size, float eps);
void var2Invar(float *save_var, int size, float eps);
void backwardAll(const float *in, const float *yt, const float *mean, const float *invar, const float *scale, int size,
int ch, float *dxhat_sum, float *dxhathat_sum, float *dbias, float *dscale, float *dx);
#ifdef __cplusplus
}
#endif

@ -20,7 +20,7 @@
int BinaryCrossEntropyGrad(const int input_size, const int reduction, const float *input_x, const float *input_y,
const float *weight, const float *dloss, float *dx) {
const float epsilon = 1e-12;
const float epsilon = 1e-12f;
if (reduction == 0) {
for (int i = 0; i < input_size; i++) {
float denominator = MAX(input_x[i] * (1 - input_x[i]), epsilon);

@ -21,7 +21,7 @@
#endif
#include "nnacl/fp32/matmul_fp32.h"
static void addv(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
void AddMatrix(const float *restrict v1, float *restrict v2, float beta, int row, int col, int stride) {
const float *src_ptr = v1;
float *dst_ptr = v2;
for (int r = 0; r < row; r++) {
@ -86,7 +86,8 @@ static void RowMajor2Row12MajorStride(const float *src_ptr, float *dst_ptr, int
return;
}
static void RowMajor2Col12MajorStride(const float *src_ptr, float *dst_ptr, size_t row, size_t col, int lead) {
static void RowMajor2Col12MajorStride(const float *restrict src_ptr, float *restrict dst_ptr, size_t row, size_t col,
int lead) {
size_t row_up_12 = UP_ROUND(row, C12NUM);
size_t row12 = row / C12NUM * C12NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -549,7 +550,7 @@ void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const floa
#else
MatMulOpt(mat_a_input, mat_b_input, output, gcb->bias, gcb->atype, K, M, N, ldc, OutType_Nhwc);
#endif
if (incremental) addv(output, mat_c, beta, M, N, ldc);
if (incremental) AddMatrix(output, mat_c, beta, M, N, ldc);
gcb->mat_a = mat_a_input;
gcb->mat_b = mat_b_input;
}

@ -37,6 +37,7 @@ void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *m
int ldb, float beta, float *mat_c, int ldc, float *workspace);
int MatSize(int row, int col, int round);
int MatSizeTotal(int row, int col, int deep, int inc);
void AddMatrix(const float *v1, float *v2, float beta, int row, int col, int stride);
#ifdef __cplusplus
}
#endif

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save