!6132 train on device

Merge pull request !6132 from yonibaehr/export
pull/6132/MERGE
mindspore-ci-bot (5 years ago), committed by Gitee
commit 2f14c40934

@ -27,7 +27,6 @@ struct Model;
}
namespace session {
class TrainSession : public lite::LiteSession {
public:
TrainSession();

@ -20,8 +20,8 @@
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
void BatchNormFp32(const void *input, const void *mean, const void *variance,
BatchNormParameter *param, int task_id, void *output) {
void BatchNormFp32(const void *input, const void *mean, const void *variance, BatchNormParameter *param, int task_id,
void *output) {
int units_per_thread = UP_DIV(param->unit_, param->op_parameter_.thread_num_);
int completed_units = task_id * units_per_thread;
int cur_unit = MSMIN(units_per_thread, param->unit_ - completed_units);
@ -31,7 +31,7 @@ void BatchNormFp32(const void *input, const void *mean, const void *variance,
for (int c = 0; c < param->channel_; c++) {
float variance_sqrt = sqrt(((const float *)variance)[c] + param->epsilon_);
((float *)output)[cur_offset + c] =
(((const float *)input)[cur_offset + c] - ((const float *)mean)[c]) / variance_sqrt;
(((const float *)input)[cur_offset + c] - ((const float *)mean)[c]) / variance_sqrt;
}
cur_offset += param->channel_;
}
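A quick worked example of the per-thread partitioning above (assumed values, not taken from the patch): with param->unit_ = 10 and thread_num_ = 4, UP_DIV(10, 4) = 3 units per thread; task_id = 3 then starts at completed_units = 9 and processes cur_unit = MSMIN(3, 10 - 9) = 1, so the last thread picks up the remainder.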
@ -53,3 +53,22 @@ void FusedBatchNormFp32(const void *input, const void *scale, const void *offset
cur_offset += param->channel_;
}
}
void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var,
BatchNormParameter *param, float *save_mean, float *save_inv_var) {
float N = param->channel_ * param->unit_;
for (int i = 0; i < param->unit_; i++) {
for (int f = 0; f < param->channel_; f++) {
int idx = i * param->channel_ + f;
run_mean[f] += input[idx];
run_var[f] += input[idx] * input[idx];
}
}
for (int f = 0; f < param->channel_; f++) {
run_mean[f] = run_mean[f] / N;
run_var[f] = run_var[f] / N - run_mean[f] * run_mean[f];
save_mean[f] = momentum * save_mean[f] + (1 - momentum) * run_mean[f];
float inv_var = 1.f / sqrt(run_var[f] + param->epsilon_);
save_inv_var[f] = momentum * save_inv_var[f] + (1 - momentum) * inv_var;
}
}
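For reference, a sketch of the per-channel running-statistics update that FusedBatchNormFp32MeanVar implements, with m = momentum, epsilon = param->epsilon_, and N the normalization count used by the code:
\mu_f = \frac{1}{N}\sum_i x_{i,f}, \qquad \sigma_f^2 = \frac{1}{N}\sum_i x_{i,f}^2 - \mu_f^2
\mathrm{save\_mean}_f \leftarrow m\,\mathrm{save\_mean}_f + (1 - m)\,\mu_f, \qquad \mathrm{save\_inv\_var}_f \leftarrow m\,\mathrm{save\_inv\_var}_f + (1 - m)\,\frac{1}{\sqrt{\sigma_f^2 + \epsilon}}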

@ -28,6 +28,8 @@ void BatchNormFp32(const void *input, const void *mean, const void *variance, Ba
void FusedBatchNormFp32(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, BatchNormParameter *param, int task_id, void *output);
void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var,
BatchNormParameter *param, float *save_mean, float *save_var);
#ifdef __cplusplus
}
#endif

@ -27,41 +27,8 @@ void sumSpatialBatch(const float *in, int size, int ch, float *out) {
}
}
void scaleBias(const float *scales, int batch, int n, int size, float *output) {
for (int i = 0; i < batch * size; i++)
for (int c = 0; c < n; c++) output[i * n + c] *= scales[c];
}
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial,
float *out) {
int b, f, i;
for (b = 0; b < batch; ++b) {
for (i = 0; i < spatial; ++i) {
for (f = 0; f < filters; ++f) {
int index = b * filters * spatial + i * filters + f;
out[index] = (x[index] - mean[f]) * invar[f];
}
}
}
}
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates) {
int i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += delta[index] * x_norm;
}
}
}
}
void meanVar(const float *in, int batch, int spatial, int ch, float eps, float *mean, float *invar) {
float N = batch * spatial;
static void meanVar(const float *in, int size, int ch, float eps, float *mean, float *invar) {
float N = (float)size;
sumSpatialBatch(in, N, ch, mean);
for (int f = 0; f < ch; ++f) {
mean[f] /= N;
@ -76,49 +43,40 @@ void meanVar(const float *in, int batch, int spatial, int ch, float eps, float *
}
}
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta) {
sumSpatialBatch(yt, size, ch, mean_delta);
for (int i = 0; i < ch; i++) mean_delta[i] *= -invar[i];
}
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
float *mean_add, float *mean_delta) {
int i, k;
memset(mean_add, 0, filters * sizeof(float));
for (k = 0; k < spatial * batch; ++k) {
for (i = 0; i < filters; ++i) {
int index = k * filters + i;
mean_add[i] += x[index] - mean[i];
void backwardX(const float *in, const float *dout, const float *scale, const int size, int channels, float eps,
float *mean, float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) {
meanVar(in, size, channels, eps, mean, invar);
for (int i = 0; i < size; i++) {
for (int f = 0; f < channels; f++) {
int ix = i*channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dxhat = dout[ix] * scale[f];
dxhat_sum[f] += dxhat;
dxhathat_sum[f] += dxhat * x_hat;
}
}
for (i = 0; i < filters; ++i) {
mean_add[i] *= variance_delta[i] * (-2.f / (spatial * batch));
mean_delta[i] += mean_add[i];
}
}
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int filters,
int spatial, float *variance_delta) {
int i, k;
memset(variance_delta, 0, filters * sizeof(float));
for (k = 0; k < batch * spatial; k++) {
for (i = 0; i < filters; i++) {
int index = k * filters + i;
variance_delta[i] += delta[index] * (x[index] - mean[i]);
for (int i = 0; i < size; i++) {
for (int f = 0; f < channels; f++) {
int ix = i*channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dxhat = dout[ix] * scale[f];
out[ix] = 1.f / size * invar[f] * (size * dxhat - dxhat_sum[f] - x_hat * dxhathat_sum[f]);
}
}
for (i = 0; i < filters; i++) variance_delta[i] *= -.5 * 1.0f/(invar[i]*invar[i]*invar[i]);
}
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta) {
int f, k;
for (k = 0; k < batch * spatial; k++) {
for (f = 0; f < filters; f++) {
int index = k * filters + f;
delta[index] = delta[index] * invar[f] +
variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) +
mean_delta[f] / (spatial * batch);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates) {
int i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += delta[index] * x_norm;
}
}
}
}
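For reference, backwardX above computes the standard batch-norm input gradient. A sketch, writing x-hat for the normalized input, gamma for scale, N for size, and assuming dxhat_sum and dxhathat_sum are zero-initialized by the caller:
\hat{x}_{i,f} = (x_{i,f} - \mu_f)\,\mathrm{invar}_f, \qquad \mathrm{invar}_f = \frac{1}{\sqrt{\sigma_f^2 + \epsilon}}
dx_{i,f} = \frac{\gamma_f\,\mathrm{invar}_f}{N}\Big(N\,dy_{i,f} - \sum_k dy_{k,f} - \hat{x}_{i,f}\sum_k dy_{k,f}\,\hat{x}_{k,f}\Big)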

@ -30,18 +30,11 @@ extern "C" {
#endif
void sumSpatialBatch(const float *in, int size, int ch, float *out);
void scaleBias(const float *scales, int batch, int n, int size, float *output);
void normalize(const float *x, const float *mean, const float *invar, int batch, int filters, int spatial, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates);
void meanVar(const float *in, int batch, int size, int ch, float eps, float *mean, float *invar);
void meanDelta(float *yt, int size, int ch, float *invar, float *mean_delta);
void varianceDelta(const float *x, const float *delta, const float *mean, const float *invar, int batch, int ch,
int spatial, float *variance_delta);
void meanAdd(const float *x, const float *mean, const float *variance_delta, int batch, int filters, int spatial,
float *mean_add, float *mean_delta);
void NormalizeDelta(const float *x, const float *mean, const float *invar, const float *mean_delta,
const float *variance_delta, int batch, int filters, int spatial, float *delta);
void backwardX(const float *in, const float *dout, const float *scale, const int size, int channels, float eps,
float *mean, float *invar, float *xhat_sum, float *dxhat_sum, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates);
#ifdef __cplusplus
}
#endif

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32_grad/softmax_grad.h"
#include <string.h>
#include "nnacl/fp32_grad/gemm.h"
void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data, float *sum_mul,
SoftmaxParameter *parameter) {
int32_t axis = parameter->axis_;
int n_dim = parameter->n_dim_;
int ele_size = parameter->element_size_;
int *input_shape = parameter->input_shape_;
int dim = 1;
int inner_size = 1, outter_size = 1;
for (int i = 0; i < axis; i++) {
outter_size *= input_shape[i];
}
for (int i = axis + 1; i < n_dim; i++) {
inner_size *= input_shape[i];
}
for (int i = 0; i < inner_size * input_shape[axis]; i++) sum_mul[i] = 1.0;
for (int i = 0; i < n_dim; i++) dim *= input_shape[i];
dim /= outter_size;
memcpy(output_ptr, yt_ptr, ele_size * sizeof(float));
int M = input_shape[axis];
int N = inner_size;
int K = 1;
for (int i = 0; i < outter_size; i++) {
int outter_offset = i * dim;
memset(sum_data, 0.0f, inner_size * sizeof(float));
for (int k = 0; k < inner_size; k++) {
int inner_offset = outter_offset + k;
for (int j = 0; j < input_shape[axis]; j++) {
int offset = inner_offset + j * inner_size;
sum_data[k] += output_ptr[offset] * input_ptr[offset];
}
}
gemm(0, 0, M, N, K, -1, sum_mul, K, sum_data, N, 1, &output_ptr[outter_offset], N);
}
for (int i = 0; i < ele_size; i++) {
output_ptr[i] *= input_ptr[i];
}
}
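A sketch of what SoftmaxGrad computes along the reduced axis, assuming input_ptr holds the forward softmax output y and yt_ptr the incoming gradient g; the gemm call with sum_mul filled with ones and alpha = -1 performs the broadcast subtraction of the inner sum:
\frac{\partial L}{\partial x_i} = y_i\Big(g_i - \sum_j g_j\,y_j\Big)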

@ -14,10 +14,15 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_SOFTMAX_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_SOFTMAX_GRAD_H_
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_SOFTMAX_GRAD_H_
#include "nnacl/op_base.h"
#include "nnacl/fp32/softmax.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct SoftmaxCrossEntropyParameter {
OpParameter op_parameter_;
@ -26,4 +31,11 @@ typedef struct SoftmaxCrossEntropyParameter {
int n_dim_;
int input_shape_[5];
} SoftmaxCrossEntropyParameter;
#endif // MINDSPORE_LITE_NNACL_FP32_SOFTMAX_GRAD_H_
void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr, float *sum_data,
float *sum_mul, SoftmaxParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_SOFTMAX_GRAD_H_

@ -21,7 +21,7 @@
namespace mindspore {
namespace lite {
static int CompareOutputRelativeData(float *output_data, float *correct_data, int data_size) {
static float CompareOutputRelativeData(float *output_data, float *correct_data, int data_size) {
float error = 0;
// relative error
@ -35,6 +35,16 @@ static int CompareOutputRelativeData(float *output_data, float *correct_data, in
diffSum += diff;
}
error = diffSum / sum;
return error;
}
int CompareRelativeOutput(float *output_data, std::string file_path) {
size_t output_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
// std::cout << "output num : " << output_num << "\n";
float error = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete [] ground_truth;
if (error > 1e-4) {
std::cout << "has accuracy error!\n" << error << "\n";
return 1;
@ -42,14 +52,15 @@ static int CompareOutputRelativeData(float *output_data, float *correct_data, in
return 0;
}
int CompareRelativeOutput(float *output_data, std::string file_path) {
float RelativeOutputError(float *output_data, std::string file_path) {
size_t output_size;
auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));
size_t output_num = output_size / sizeof(float);
std::cout << "output num : " << output_num << "\n";
int res = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete[] ground_truth;
return res;
float error = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete [] ground_truth;
return error;
}
} // namespace lite
} // namespace mindspore
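A minimal usage sketch for the new helper (the output tensor variable and the .bin path are hypothetical; the 1e-4 tolerance mirrors what CompareRelativeOutput enforces internally):
auto *out = reinterpret_cast<float *>(out_tensor->MutableData());  // hypothetical kernel output tensor
float err = mindspore::lite::RelativeOutputError(out, "./test_data/bn_grad_dx.bin");  // hypothetical ground-truth file
ASSERT_LE(err, 1e-4);  // gtest-style check; CompareRelativeOutput(out, path) returns 1 past this threshold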

@ -21,6 +21,7 @@
namespace mindspore {
namespace lite {
int CompareRelativeOutput(float *output_data, std::string file_path);
float RelativeOutputError(float *output_data, std::string file_path);
}
} // namespace mindspore
#endif // MINDSPORE_LITE_COMMON_FILE_UTILS_EXT_H_

@ -32,13 +32,16 @@
namespace mindspore {
namespace lite {
static std::vector<schema::PrimitiveType> packed_op = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
// This method does not check whether tensor_idx is a weight tensor index; the caller must ensure this.
static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
#ifdef SUPPORT_TRAIN
return false;
#endif
MS_ASSERT(model != nullptr);
auto post_node_idxes = GetLinkedPostNodeIdx(model, tensor_idx);
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
@ -267,7 +270,9 @@ int LiteSession::CompileGraph(Model *model) {
}
executor->Prepare(this->kernels_);
#ifndef SUPPORT_TRAIN
model->Free();
#endif
return RET_OK;
}

@ -42,9 +42,11 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
for (uint32_t j = 0; j < count; ++j) {
node->input_indices_.push_back(size_t(c_node->inputIndex()->GetAs<uint32_t>(j)));
}
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
if (c_node->outputIndex() != nullptr) {
count = c_node->outputIndex()->size();
for (uint32_t j = 0; j < count; ++j) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
}
}
model->nodes_.push_back(node);
}

@ -46,6 +46,8 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
// auto alpha = GetValue<float>(prim.GetAttr("alpha"));
attr->alpha = 0; // alpha;
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";

@ -16,7 +16,6 @@
#include "src/ops/apply_momentum.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
@ -31,11 +30,17 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
auto attr = std::make_unique<schema::ApplyMomentumT>();
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
auto attr = std::make_unique<schema::ApplyMomentumT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
@ -49,13 +54,13 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
int ApplyMomentum::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (5 != inputs.size()) {
MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors";
return RET_ERROR;

@ -48,6 +48,9 @@ int ArithmeticGrad::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
if ((Type() == schema::PrimitiveType_AddGrad) || (Type() == schema::PrimitiveType_SubGrad)) {
ndim_ = outShape.size();
x1_shape_.resize(ndim_);
x2_shape_.resize(ndim_);
dy_shape_.resize(ndim_);
auto fillDimNum0 = outShape.size() - inShape0.size();
auto fillDimNum1 = outShape.size() - inShape1.size();
int j0 = 0;
@ -61,6 +64,9 @@ int ArithmeticGrad::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
// if (inShape0.size() < inShape1.size())
if (dx1->ElementsNum() < dx2->ElementsNum()) {
ndim_ = inShape1.size();
x1_shape_.resize(ndim_);
x2_shape_.resize(ndim_);
dy_shape_.resize(ndim_);
auto fillDimNum = inShape1.size() - inShape0.size(); // This will not work for batch!
int j = 0;
for (unsigned int i = 0; i < inShape1.size(); i++) {
@ -74,8 +80,10 @@ int ArithmeticGrad::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<
}
} else if (dx2->ElementsNum() < dx1->ElementsNum()) { // if (inShape0.size() > inShape1.size())
ndim_ = inShape0.size();
x1_shape_.resize(ndim_);
x2_shape_.resize(ndim_);
dy_shape_.resize(ndim_);
broadcasting_ = true;
ndim_ = inShape0.size();
int j = 0;
auto fillDimNum = inShape0.size() - inShape1.size();
for (unsigned int i = 0; i < inShape0.size(); i++) {
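The x1_shape_/x2_shape_/dy_shape_ resize() calls added in this hunk matter because std::vector::operator[] never grows a vector, so indexing before resizing is undefined behavior. A minimal self-contained illustration (hypothetical values; needs <vector>):
std::vector<int> shape;                            // empty: shape[0] here would be undefined behavior
int ndim = 4;
shape.resize(ndim);                                // now ndim value-initialized (zero) elements
for (int i = 0; i < ndim; i++) shape[i] = i + 1;   // safe only after resize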

@ -32,7 +32,7 @@ class ArithmeticGrad : public PrimitiveC {
ArithmeticGrad() = default;
explicit ArithmeticGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
// explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
// explicit ArithmeticGrad(const schema::Primitive &primitive) : PrimitiveC(primitive) {}
ArithmeticGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;

@ -41,6 +41,7 @@ int BiasGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = {0}; // GetValue<std::vector<int>>(prim.GetAttr("axis"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
@ -73,6 +74,7 @@ std::vector<int> BiasGrad::GetAxis() const {
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
#endif
int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (1 != inputs.size()) {
@ -99,6 +101,5 @@ int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> out
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore

@ -38,8 +38,8 @@ class BiasGrad : public PrimitiveC {
BiasGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
int InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) override;
std::vector<int> GetAxis() const;
};
} // namespace lite

@ -67,9 +67,31 @@ int BNGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
fbb->Finish(prim_offset);
return RET_OK;
}
float BNGrad::GetEps() const { return this->primitive_->value_as_BNGrad()->eps(); }
float BNGrad::GetMomentum() const { return this->primitive_->value_as_BNGrad()->momentum(); }
#endif
int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (5 != inputs.size()) {
MS_LOG(ERROR) << "BNGrad should have five inputs";
return RET_ERROR;
}
if (3 != outputs.size()) {
MS_LOG(ERROR) << "BNGrad should have three outputs";
return RET_ERROR;
}
auto in = inputs[1];
auto scale = inputs[2];
outputs[0]->set_shape(in->shape());
outputs[1]->set_shape(scale->shape());
outputs[2]->set_shape(scale->shape());
outputs[0]->set_data_type(in->data_type());
outputs[1]->set_data_type(scale->data_type());
outputs[2]->set_data_type(scale->data_type());
outputs[0]->SetFormat(in->GetFormat());
outputs[1]->SetFormat(scale->GetFormat());
outputs[2]->SetFormat(scale->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -38,6 +38,8 @@ class BNGrad : public PrimitiveC {
BNGrad() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_,
std::vector<lite::Tensor *> outputs_) override;
float GetEps() const;
float GetMomentum() const;
};

@ -1,75 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/bn_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float BNGradInput::GetEps() const { return this->primitive_->value.AsBNGradInput()->eps; }
float BNGradInput::GetMomentum() const { return this->primitive_->value.AsBNGradInput()->momentum; }
void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->eps = eps; }
void BNGradInput::SetMomentum(float momentum) { this->primitive_->value.AsBNGradInput()->momentum = momentum; }
int BNGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BNGradInput;
}
if (this->primitive_->value.type != schema::PrimitiveType_BNGradInput) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BNGradInputT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->eps = GetValue<float>(prim.GetAttr("eps"));
attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BNGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BNGradInputInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->momentum());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
float BNGradInput::GetMomentum() const { return this->primitive_->value_as_BNGradInput()->momentum(); }
#endif
} // namespace lite
} // namespace mindspore

@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class BNGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
BNGradInput() = default;
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetMomentum(float momentum);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
BNGradInput() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEps() const;
float GetMomentum() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_

@ -66,108 +66,7 @@ void Conv2DGradFilter::SetHasBias(bool has_bias) { this->primitive_->value.AsCon
void Conv2DGradFilter::SetActivationType(int activation_type) {
this->primitive_->value.AsConv2DGradFilter()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradFilter::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.value = attr.release();
}
void Conv2DGradFilter::PopulaterConv2DSingleGroup(const Primitive &prim,
schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.value = attr.release();
}
int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@ -181,11 +80,62 @@ int Conv2DGradFilter::UnPackAttr(const Primitive &prim, const std::vector<AnfNod
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::Conv2DGradFilterT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->group = GetValue<int>(prim.GetAttr("group"));
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
@ -268,6 +218,5 @@ int Conv2DGradFilter::InferShape(std::vector<Tensor *> inputs, std::vector<Tenso
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -51,9 +51,6 @@ class Conv2DGradFilter : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
Conv2DGradFilter() = default;

@ -64,108 +64,7 @@ void Conv2DGradInput::SetHasBias(bool has_bias) { this->primitive_->value.AsConv
void Conv2DGradInput::SetActivationType(int activation_type) {
this->primitive_->value.AsConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
void Conv2DGradInput::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs) {
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
int channel_mutiplier = 1;
if (prim.GetAttr("channel_mutiplier") != nullptr) {
channel_mutiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
}
attr->channelMultiplier = channel_mutiplier;
primitive->value.value = attr.release();
}
void Conv2DGradInput::PopulaterConv2DSingleGroup(const Primitive &prim,
schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::Conv2DT>();
attr->group = group;
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[2];
attr->strideW = stride[3];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
primitive->value.value = attr.release();
}
int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@ -179,11 +78,63 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {
PopulaterConv2DSingleGroup(prim, this->primitive_, group);
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::Conv2DGradInputT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->group = GetValue<int>(prim.GetAttr("group"));
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;
} else if (format == "NHWC") {
attr->format = schema::Format_NHWC;
} else {
attr->format = schema::Format_NUM_OF_FORMAT;
}
auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
attr->padUp = pad_list[0];
attr->padDown = pad_list[1];
attr->padLeft = pad_list[2];
attr->padRight = pad_list[3];
auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
attr->dilateH = dilation[0];
attr->dilateW = dilation[1];
auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
attr->kernelH = kernel_size[0];
attr->kernelW = kernel_size[1];
auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
attr->strideH = stride[0];
attr->strideW = stride[1];
attr->channelOut = GetValue<int>(prim.GetAttr("out_channel"));
auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
if (pad_mode == "valid") {
attr->padMode = schema::PadMode_VALID;
} else if (pad_mode == "same") {
attr->padMode = schema::PadMode_SAME;
} else {
attr->padMode = schema::PadMode_NOTSET;
}
if (prim.GetAttr("activation_name") != nullptr) {
std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
attr->activationType = kActivationTypeMap[activate_name];
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
@ -265,6 +216,5 @@ int Conv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -51,9 +51,6 @@ class Conv2DGradInput : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
void PopulaterConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
#else
Conv2DGradInput() = default;

Some files were not shown because too many files have changed in this diff.