!5809 train on device

Merge pull request !5809 from yonibaehr/export
pull/5809/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
commit 5c42198430

@@ -69,6 +69,7 @@ class MS_API Model {
/// \brief Free MetaGraph in MindSpore Lite Model.
void FreeMetaGraph();
ModelImpl *model_impl() {return model_impl_;}
protected:
ModelImpl *model_impl_ = nullptr;

@@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
#include <vector>
#include <string>
#include <unordered_map>
// #include "include/lite_session.h"
#include "src/lite_session.h"
namespace mindspore {
namespace lite {
class Model;
}
namespace lite::tensor {
class Tensor;
}
namespace session {
class TrainSession : public lite::LiteSession {
public:
TrainSession();
~TrainSession() = default;
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;
int CompileGraph(lite::Model *model) override;
virtual void ReplaceOps();
virtual void* ExportToBuf(void* buf, size_t* len) const;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> GetOutputs() const;
std::vector<tensor::MSTensor *> GetOutputsByName(const std::string &node_name) const;
virtual void train();
bool is_train() { return train_mode_ == true; }
virtual void eval();
bool is_eval() { return train_mode_ == false; }
protected:
bool train_mode_ = false;
lite::Model* model_ = nullptr;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> ext_output_map_;
// private:
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
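A hedged usage sketch of the API this header introduces, using only the methods declared above; session construction and error handling are simplified assumptions, not code from this PR:

void TrainThenExport(mindspore::session::TrainSession *session,
                     mindspore::lite::Model *model) {
  if (session->CompileGraph(model) != 0) {
    return;  // compilation failed
  }
  session->train();     // switch the graph to training mode (gradient kernels active)
  session->RunGraph();  // one forward + backward step
  session->eval();      // back to inference mode, e.g. for validation
  size_t len = 0;
  void *buf = session->ExportToBuf(nullptr, &len);  // assumption: a null buf asks
  (void)buf;                                        // the session to allocate it
}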

@@ -13,9 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/activation_grad.h"
int ReluGrad(float *src0, float *src1, int length, float *dst) {
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/fp32/arithmetic.h"
#include "nnacl/fp32_grad/activation_grad.h"
#include "nnacl/errorcode.h"
inline int ReluGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
dst[i] = src1[i] > 0 ? 1.0f : 0.0f;
}
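The visible loop only writes the 0/1 mask into dst; presumably the lines truncated from this hunk multiply that mask by the incoming gradient. A self-contained sketch of the end-to-end ReLU gradient, under that assumption:

static void ReluGradSketch(const float *dy, const float *x, int length, float *dx) {
  // dx[i] = dy[i] * 1{x[i] > 0}: gradient passes only where the input was positive
  for (int i = 0; i < length; ++i) {
    dx[i] = x[i] > 0.0f ? dy[i] : 0.0f;
  }
}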

@@ -13,11 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string.h>
#include <math.h>
#include <string.h>
#include "nnacl/fp32_grad/batch_norm.h"
static void sumSpatialBatch(const float *in, int size, int ch, float *out) {
void sumSpatialBatch(const float *in, int size, int ch, float *out) {
memset(out, 0, ch * sizeof(float));
for (int i = 0; i < size; i++) {
const float *ptr = in + i * ch;
@@ -32,49 +32,53 @@ void scaleBias(const float *scales, int batch, int n, int size, float *output) {
for (int c = 0; c < n; c++) output[i * n + c] *= scales[c];
}
void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial,
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial,
float *out) {
int b, f, i;
for (b = 0; b < batch; ++b) {
for (i = 0; i < spatial; ++i) {
for (f = 0; f < filters; ++f) {
int index = b * filters * spatial + i * filters + f;
out[index] = (x[index] - mean[f]) / (sqrt(variance[f]) + eps);
out[index] = (x[index] - mean[f]) * invar[f];
}
}
}
}
void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates) {
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates) {
int i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
scale_updates[f] += delta[index] * x_norm[index];
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += delta[index] * x_norm;
}
}
}
}
void meanVar(const float *in, int batch, int spatial, int ch, float *mean, float *var) {
void meanVar(const float *in, int batch, int spatial, int ch, float eps, float *mean, float *invar) {
float N = batch * spatial;
sumSpatialBatch(in, N, ch, mean);
for (int f = 0; f < ch; ++f) mean[f] /= N;
memset(var, 0, ch * sizeof(float));
for (int i = 0; i < N; i++) {
for (int f = 0; f < ch; f++) {
float x = in[i * ch + f];
var[f] += (x - mean[f]) * (x - mean[f]);
for (int f = 0; f < ch; ++f) {
mean[f] /= N;
}
for (int f = 0; f < ch; f++) {
float tvar = 0;
for (int i = 0; i < N; i++) {
float x = in[i * ch + f];
tvar += (x - mean[f]) * (x - mean[f]);
}
invar[f] = 1.0f / (sqrt(tvar / N + eps));
}
for (int f = 0; f < ch; f++) var[f] /= N;
}
void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta) {
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta) {
sumSpatialBatch(yt, size, ch, mean_delta);
for (int i = 0; i < ch; i++) mean_delta[i] *= -1.f / sqrt((variance[i] + eps));
for (int i = 0; i < ch; i++) mean_delta[i] *= -invar[i];
}
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
@@ -93,8 +97,8 @@ void meanAdd(const float *x, const float *mean, const float *variance_delta, int
}
}
void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int filters,
int spatial, float eps, float *variance_delta) {
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int filters,
int spatial, float *variance_delta) {
int i, k;
memset(variance_delta, 0, filters * sizeof(float));
for (k = 0; k < batch * spatial; k++) {
@@ -103,16 +107,16 @@ void varianceDelta(const float *x, const float *delta, const float *mean, const
variance_delta[i] += delta[index] * (x[index] - mean[i]);
}
}
for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * pow(variance[i] + eps, (-3.f / 2.f));
for (i = 0; i < filters; i++) variance_delta[i] *= -.5f * invar[i] * invar[i] * invar[i];  // (var + eps)^(-3/2) == invar^3
}
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta) {
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta) {
int f, k;
for (k = 0; k < batch * spatial; k++) {
for (f = 0; f < filters; f++) {
int index = k * filters + f;
delta[index] = delta[index] * 1. / (sqrt(variance[f] + eps)) +
delta[index] = delta[index] * invar[f] +
variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) +
mean_delta[f] / (spatial * batch);
}
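The thread running through this whole file is the switch from passing raw variance plus eps into every helper to passing one precomputed inverse standard deviation. A minimal sketch of the identity being exploited (names illustrative):

#include <math.h>

// With invar[f] = 1 / sqrt(var[f] + eps):
//   (x - mean) / sqrt(var + eps)  ==  (x - mean) * invar      (normalize)
//   (var + eps)^(-3/2)            ==  invar * invar * invar   (varianceDelta)
static void PrecomputeInvStd(const float *var, float eps, int ch, float *invar) {
  for (int f = 0; f < ch; ++f) {
    invar[f] = 1.0f / sqrtf(var[f] + eps);  // computed once, reused by every helper
  }
}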

@@ -17,28 +17,33 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_
#define MINDSPORE_LITE_NNACL_FP32_BATCH_NORM_H_
typedef struct bnParameter {
int batch;
int channels;
int spatial;
float eps;
} bnParameter;
#include "nnacl/op_base.h"
typedef struct BNGradParameter {
OpParameter op_parameter_;
float epsilon_;
float momentum_;
} BNGradParameter;
#ifdef __cplusplus
extern "C" {
#endif
void sumSpatialBatch(const float *in, int size, int ch, float *out);
void scaleBias(const float *scales, int batch, int n, int size, float *output);
void normalize(const float *x, const float *mean, const float *variance, float eps, int batch, int filters, int spatial,
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial,
float *out);
void backwardScale(const float *x_norm, const float *delta, int batch, int n, int size, float *scale_updates);
void meanVar(const float *in, int batch, int size, int ch, float *mean, float *var);
void meanDelta(float *yt, int size, int ch, float eps, float *variance, float *mean_delta);
void varianceDelta(const float *x, const float *delta, const float *mean, const float *variance, int batch, int ch,
int spatial, float eps, float *variance_delta);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates);
void meanVar(const float *in, int batch, int size, int ch, float eps, float *mean, float *invar);
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta);
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int ch,
int spatial, float *variance_delta);
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
float *mean_add, float *mean_delta);
void NormalizeDelta(const float *x, const float *mean, const float *variance, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float eps, float *delta);
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta);
#ifdef __cplusplus
}
#endif
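A note on the new struct layout: placing OpParameter as the first member follows the pattern used across nnacl, so a kernel can recover its typed parameters from the generic pointer the runtime hands it. A hedged illustration using the BNGradParameter declared above (values illustrative):

BNGradParameter param;
param.epsilon_ = 1e-5f;
param.momentum_ = 0.9f;
OpParameter *base = (OpParameter *)&param;      // passed around generically
BNGradParameter *bn = (BNGradParameter *)base;  // first-member cast is well-defined in C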

@@ -125,9 +125,9 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param
}
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;

@@ -13,7 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <stdint.h>
#include <float.h>
#include "nnacl/fp32_grad/pooling_grad.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
@@ -31,33 +32,37 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
int output_batch = pooling_param->output_batch_;
const float *inPtr = NULL;
for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
// for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
float kk = (float)(win_h * win_w);
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out;
out = &output_ptr[(ib * output_h * output_w)];
inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]);
// out = &output_ptr[(ib * output_h * output_w)];
out = &output_ptr[(ib * in_h * in_w * channel)];
// inPtr = (float *)(&input_ptr[(ib * in_h * in_w)]);
inPtr = (float *)(&input_ptr[(ib * output_h * output_w * channel)]);
if (1) { // in->layout() == Tensor::nhwc)
// iterate over yt
for (uint16_t yh = 0; yh < in_h; yh++) {
for (uint16_t yw = 0; yw < in_w; yw++) {
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (yw + yh * in_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw;
int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw;
float delta = inPtr[idx] / kk;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= output_h)) {
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= output_w)) {
if ((xw < 0) || (xw >= in_w)) {
continue;
}
// out[(ic*output_h*output_w) + (xh*output_w) + xw] += delta;
out[(xw + output_w * xh) * channel + ic] += delta;
// out[(xw + output_w * xh) * channel + ic] += delta;
out[(xw + in_w * xh) * channel + ic] += delta;
}
}
}
@@ -66,21 +71,22 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
} else { // nchw
for (uint16_t ic = 0; ic < channel; ic++) {
// iterate over yt
for (uint16_t yh = 0; yh < in_h; yh++) {
for (uint16_t yw = 0; yw < in_w; yw++) {
int idx = (ic * in_h * in_w) + (in_w * yh) + yw;
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
int idx = (ic * output_h * output_w) + (output_w * yh) + yw;
float delta = inPtr[idx] / kk;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= output_h)) {
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= output_w)) {
if ((xw < 0) || (xw >= in_w)) {
continue;
}
out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta;
// out[(ic * output_h * output_w) + (xh * output_w) + xw] += delta;
out[(ic * in_h * in_w) + (xh * in_w) + xw] += delta;
}
}
}
@@ -90,7 +96,14 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
}
}
void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, PoolingParameter *pooling_param) {
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
@@ -98,38 +111,73 @@ void MaxPoolingGrad(const float *dy, const int *indices, float *output_ptr, Pool
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_img_size =
output_h * output_w; // Emir -- in original code this variable is calculated according to input size ??
int ind_img_size = in_h * in_w;
// const int w_pad = (output_w + pad_w + pad_w);
const float *inPtr;
const float *dyPtr;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out;
out = &output_ptr[(ib * in_h * in_w * channel)];
inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);
dyPtr = (const float *)(&dy_ptr[(ib * output_h * output_w * channel)]);
for (int i = 0; i < output_h * output_w * channel * output_batch; i++) output_ptr[i] = 0.0;
if (1) { // nhwc
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (yw + yh * output_w) * channel + ic;
const float *yt = (const float *)(dy);
const int *pos = (const int *)(indices);
float *out = NULL;
float delta = dyPtr[idx];
float max_val = -FLT_MAX;
int max_idx = 0;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= in_w)) {
continue;
}
if (1) { // grads->layout() == Tensor::nhwc)
for (int ib = 0; ib < output_batch; ib++) {
out = &(output_ptr[ib * output_w * output_w * channel]);
for (int ix = 0; ix < ind_img_size; ix++) {
for (int cix = 0; cix < channel; cix++) {
int idx = (*pos) * channel + cix;
out[idx] += *yt;
pos++;
yt++;
if (inPtr[(xw + in_w * xh) * channel + ic] > max_val) {
max_val = inPtr[(xw + in_w * xh) * channel + ic];
max_idx = (xw + in_w * xh) * channel + ic;
}
}
}
out[max_idx] += delta;
}
}
}
}
} else {
for (int ib = 0; ib < output_batch; ib++) {
out = &output_ptr[(ib * out_img_size)];
for (int cix = 0; cix < channel; cix++) {
for (int ix = 0; ix < ind_img_size; ix++) {
int idx = cix * output_h * output_w + *pos; // cord_y*output_w + cord_x;
out[idx] += *yt;
pos++;
yt++;
} else { // nchw
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (ic * output_h * output_w) + (output_w * yh) + yw;
float delta = dyPtr[idx];
float max_val = -FLT_MAX;
int max_idx = 0;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
if ((xh < 0) || (xh >= in_h)) {
continue;
}
for (int32_t kw = 0; kw < win_w; kw++) {
int xw = yw * stride_w + kw - pad_w;
if ((xw < 0) || (xw >= in_w)) {
continue;
}
if (inPtr[(ic * in_h * in_w) + (xh * in_w) + xw] > max_val) {
max_val = inPtr[(ic * in_h * in_w) + (xh * in_w) + xw];
max_idx = (ic * in_h * in_w) + (xh * in_w) + xw;
}
}
}
out[max_idx] += delta;
}
}
}
}
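The rewrite drops the stored-indices path of the old signature and instead recomputes, for each output cell, the argmax over its pooling window in the saved forward input, then routes the incoming gradient to that position. A condensed single-window sketch of that inner step (one channel shown; names illustrative):

#include <float.h>

static void MaxPoolGradWindow(const float *x, float dy_val, int in_h, int in_w,
                              int y0, int x0, int win_h, int win_w, float *dx) {
  float max_val = -FLT_MAX;
  int max_idx = 0;
  for (int kh = 0; kh < win_h; ++kh) {
    int xh = y0 + kh;  // y0 = yh * stride_h - pad_h, as in the loops above
    if (xh < 0 || xh >= in_h) continue;
    for (int kw = 0; kw < win_w; ++kw) {
      int xw = x0 + kw;
      if (xw < 0 || xw >= in_w) continue;
      if (x[xh * in_w + xw] > max_val) {
        max_val = x[xh * in_w + xw];
        max_idx = xh * in_w + xw;
      }
    }
  }
  dx[max_idx] += dy_val;  // only the window maximum receives gradient
}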

@@ -23,7 +23,9 @@
extern "C" {
#endif
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
// void MaxPoolingGrad(const float *dy, const int *indices_ptr, float *output_ptr, PoolingParameter *pooling_param);
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param);
#ifdef __cplusplus
}
#endif

@@ -13,10 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string.h>
#include "nnacl/fp32_grad/reduce_grad.h"
static inline bool NextIndex(const int num_dims, const int *dims, int *current) {
static inline int NextIndex(const int num_dims, const int *dims, int *current) {
int carry = 1;
for (int idx = num_dims - 1; idx >= 0; --idx) {
int current_val = current[idx] + carry;
@@ -45,10 +45,10 @@ static inline size_t GetOutputOffset(const int num_dims, const int *dims, const
size_t offset = 0;
for (int idx = 0; idx < num_dims; ++idx) {
// if we need to skip this axis
bool is_axis = false;
int is_axis = 0;
for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
if (idx == axes[axis_idx]) {
is_axis = true;
is_axis = 1;
break;
}
}
@@ -101,10 +101,10 @@ float ReduceMeanAll(const float *src, int size) {
void ReduceSumByAxes(const float *input, const int *input_dims, float *output, const int *output_dims, int num_dims) {
int num_outputs = 1;
int same_shape = true;
int same_shape = 1;
for (int idx = 0; idx < num_dims; ++idx) {
num_outputs *= output_dims[idx];
if (output_dims[idx] != input_dims[idx]) same_shape = false;
if (output_dims[idx] != input_dims[idx]) same_shape = 0;
}
if (same_shape) {
memcpy(output, input, num_outputs * sizeof(float));
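These bool-to-int edits are compatibility fixes: this file is compiled as C (note the extern "C" header), which has no built-in bool. For context, a self-contained sketch of the odometer pattern NextIndex implements, with the return semantics inferred from the edit above:

static int NextIndexSketch(int num_dims, const int *dims, int *current) {
  int carry = 1;
  for (int idx = num_dims - 1; idx >= 0 && carry; --idx) {
    int v = current[idx] + carry;
    carry = v / dims[idx];        // 1 only when this digit wraps around
    current[idx] = v % dims[idx];
  }
  return carry == 0;  // 0 once every coordinate has wrapped back to zero
}

// Typical driver: visit every coordinate of a {2, 3} grid.
// int dims[2] = {2, 3}; int cur[2] = {0, 0};
// do { /* use cur */ } while (NextIndexSketch(2, dims, cur));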

@@ -17,8 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_REDUCE_GRAD_H_
#include <cstddef.h>
#include <algorithm.h>
#include <stddef.h>
#ifdef __cplusplus
extern "C" {

@@ -20,7 +20,7 @@
#include "nnacl/op_base.h"
typedef struct SoftmaxCrossEntropyParameter {
OpParameter op_parameter;
OpParameter op_parameter_;
int32_t batch_size_;
unsigned int number_of_classes_;
int n_dim_;

@@ -178,8 +178,8 @@ union PrimitiveType {
Conv2DGradFilter,
Conv2DGradInput,
PoolingGrad,
BNGradInput,
OptMomentum,
BNGrad,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
AddGrad,
@@ -190,6 +190,7 @@ union PrimitiveType {
ActivationGrad,
PriorBox,
SpaceToBatchND,
Depend,
Return,
MakeTuple,
ToFormat,

@@ -149,7 +149,8 @@ table Activation {
alpha: float = 0.2;
}
table ActivationGrad {
type: ActivationGradType = 0;
type: ActivationType = 0;
alpha: float = 0.2;
}
@@ -230,6 +231,9 @@ table SoftmaxCrossEntropy {
axis: [int];
}
table make_tuple {
}
table PoolingGrad {
format: Format = 0;
@@ -390,10 +394,11 @@ table DeConv2D {
hasBias: bool = false;
activationType: ActivationType = 0;
}
table BNGradInput {
table BNGrad {
eps : float;
channels: int;
momentum: float;
}
table Scale {
axis: int;
}
@@ -841,7 +846,10 @@ table SquaredDifference {
table TupleGetItem {
}
table OptMomentum {
table ApplyMomentum {
gradientScale: float;
useLocking: bool;
useNesterov: bool;
}
@@ -884,6 +892,10 @@ table ToFormat {
dstT: int;
}
table Depend {
}
table Return {
}

@@ -27,7 +27,7 @@ set(LITE_SRC
)
if (SUPPORT_GPU)
set(LITE_SRC
set(LITE_SRC
${LITE_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/subgraph_opencl_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/kernel/opencl/utils.cc
@@ -36,6 +36,24 @@ if (SUPPORT_GPU)
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_runtime.cc
${CMAKE_CURRENT_SOURCE_DIR}/runtime/opencl/opencl_wrapper.cc
)
endif()
if (SUPPORT_TRAIN)
set(ANF_SRC
${ANF_SRC}
)
set(PASS_SRC)
set(LITE_SRC
${LITE_SRC}
${ANF_SRC}
# ${CMAKE_CURRENT_SOURCE_DIR}/train/ops/train_ops.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/train/train_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
)
endif ()
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)

@@ -110,6 +110,7 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) {
}
}
error /= data_size;
if (error > 0.0001) {
printf("has accuracy error!\n");
printf("%f\n", error);
@@ -118,12 +119,14 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) {
return 0;
}
void CompareOutput(float *output_data, std::string file_path) {
int CompareOutput(float *output_data, std::string file_path) {
size_t output_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
printf("output num : %zu\n", output_num);
CompareOutputData(output_data, ground_truth, output_num);
int res = CompareOutputData(output_data, ground_truth, output_num);
delete [] ground_truth;
return res;
}
} // namespace lite
} // namespace mindspore
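Two fixes in one here: the comparator now frees the buffer ReadFile allocated (plugging a leak) and propagates the comparison result instead of discarding it. A hedged usage sketch, assuming the gtest harness the lite tests use and an illustrative file path:

int res = mindspore::lite::CompareOutput(output_data, "expected_out.bin");
ASSERT_EQ(res, 0);  // tests can now fail on accuracy errors instead of only printing them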

@@ -47,7 +47,7 @@ void WriteToTxt(const std::string& file_path, void *data, size_t element_size) {
int WriteToBin(const std::string& file_path, void *data, size_t size);
int CompareOutputData(float *output_data, float *correct_data, int data_size);
void CompareOutput(float *output_data, std::string file_path);
int CompareOutput(float *output_data, std::string file_path);
std::string GetAndroidPackageName();
std::string GetAndroidPackagePath();

@@ -47,7 +47,9 @@ int CompareRelativeOutput(float *output_data, std::string file_path) {
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
std::cout << "output num : " << output_num << "\n";
return CompareOutputRelativeData(output_data, ground_truth, output_num);
int res = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete [] ground_truth;
return res;
}
} // namespace lite
} // namespace mindspore

@@ -39,6 +39,10 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
}
}
kernel::LiteKernelUtil::InitTensorRefCount(kernels);
for (auto out_tensor : out_tensors) { // increase RefCount of output tensors, such that Run will not free them
out_tensor->SetRefCount(out_tensor->RefCount() + 1);
}
for (auto *kernel : kernels) {
MS_ASSERT(nullptr != kernel);
@@ -48,6 +52,8 @@ int Executor::Run(std::vector<tensor::Tensor *> &in_tensors, std::vector<tensor:
MS_LOG(ERROR) << "run kernel before_callback failed, name: " << kernel->name();
}
}
// JBDEBUG
// std::cout << "executing kernel " << kernel->name() << "\n";
auto ret = kernel->Run();
if (0 != ret) {
MS_LOG(ERROR) << "run kernel failed, name: " << kernel->name();

@@ -27,7 +27,6 @@
#include "src/ir/tensor.h"
#include "include/errorcode.h"
// using mindspore::kernel::AddressPtr;
namespace mindspore::kernel {
using mindspore::lite::RET_ERROR;

@@ -112,11 +112,11 @@ int ModelImpl::BuildOps() {
Model *Model::Import(const char *model_buf, size_t size) {
auto model = new Model();
model->model_impl_ = ModelImpl::Import(model_buf, size);
if (model_buf == nullptr) {
MS_LOG(ERROR) << "model buf is null";
return nullptr;
}
model->model_impl_ = ModelImpl::Import(model_buf, size);
if (model->model_impl_ == nullptr) {
MS_LOG(ERROR) << "model impl is null";
return nullptr;

@@ -20,11 +20,11 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; }
float ActivationGrad::GetAlpha() const { return this->primitive_->value.AsActivationGrad()->alpha; }
void ActivationGrad::SetType(int type) {
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationType)type;
}
void ActivationGrad::SetAlpha(float alpha) { this->primitive_->value.AsActivationGrad()->alpha = alpha; }
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
@@ -40,7 +40,7 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat
return RET_OK;
}
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); }
#endif
} // namespace lite
} // namespace mindspore
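The new alpha attribute mirrors the forward Activation table and is presumably consumed by the LeakyReLU gradient. A minimal sketch of that use, as an assumption rather than code from this PR:

static void LeakyReluGradSketch(const float *dy, const float *x, float alpha,
                                int length, float *dx) {
  for (int i = 0; i < length; ++i) {
    dx[i] = x[i] > 0.0f ? dy[i] : alpha * dy[i];  // alpha scales the negative side
  }
}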

@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_
#include <vector>
#include <set>
@@ -32,13 +32,15 @@ class ActivationGrad : public PrimitiveC {
ActivationGrad() = default;
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
void SetAlpha(float alpha);
#else
ActivationGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
float GetAlpha() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#endif // MINDSPORE_LITE_SRC_OPS_ACTIVATION_GRAD_H_

@@ -0,0 +1,64 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/apply_momentum.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
#else
int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ApplyMomentum();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
int ApplyMomentum::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
if (5 != inputs.size()) {
MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors";
return RET_ERROR;
}
// if (outputs.empty()) {
// MS_LOG(ERROR) << "ApplyMomentumCPUKernel error input output size!";
// return RET_ERROR;
// }
if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[3]->ElementsNum() ||
inputs[2]->ElementsNum() != 1 || inputs[4]->ElementsNum() != 1) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}
if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
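The element-count checks above pin down the five inputs' shapes (inputs 0, 1, and 3 match; inputs 2 and 4 are scalars) but not their roles; assuming the conventional ordering (weight, accumulation, learning rate, gradient, momentum, as in TensorFlow's ApplyMomentum), the op performs the classic momentum update. A hedged sketch:

static void ApplyMomentumSketch(float *weight, float *accum, float lr,
                                const float *grad, float momentum, int n) {
  for (int i = 0; i < n; ++i) {
    accum[i] = accum[i] * momentum + grad[i];  // decaying gradient average
    weight[i] -= lr * accum[i];                // descend along the accumulator
  }
}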

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_
#define MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class ApplyMomentum : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ApplyMomentum, PrimitiveC);
ApplyMomentum() = default;
explicit ApplyMomentum(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
ApplyMomentum() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_APPLY_MOMENTUM_H_

Some files were not shown because too many files have changed in this diff.
