!7228 tod add networks and ops

Merge pull request !7228 from yonibaehr/export
pull/7228/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b3d43aa616

@ -235,3 +235,4 @@ if (NOT WIN32)
endif ()
include(${TOP_DIR}/cmake/package_lite.cmake)

@ -17,37 +17,28 @@
#define MINDSPORE_LITE_INCLUDE_TRAIN_SESSION_H_
#include <vector>
#include <string>
#include <tuple>
#include <unordered_map>
#include "src/lite_session.h"
#include "include/lite_session.h"
#include "include/train_model.h"
namespace mindspore {
namespace lite {
struct TrainModel;
}
namespace session {
class TrainSession : public lite::LiteSession {
public:
TrainSession();
~TrainSession();
int RunGraph(const session::KernelCallBack &before = nullptr,
const session::KernelCallBack &after = nullptr) override;
int CompileGraph(lite::Model *model) override;
virtual void* ExportToBuf(char* buf, size_t* len) const;
class TrainSession : public session::LiteSession {
public:
virtual ~TrainSession() = default;
static TrainSession *CreateSession(lite::Context *context);
virtual void Train();
virtual int CompileTrainGraph(lite::TrainModel *model) = 0;
virtual void *ExportToBuf(char *buf, size_t *len) const = 0;
virtual void Train() = 0;
bool IsTrain() { return train_mode_ == true; }
virtual void Eval();
virtual void Eval() = 0;
bool IsEval() { return train_mode_ == false; }
protected:
virtual void ReplaceOps();
bool train_mode_ = false;
lite::TrainModel *model_ = nullptr;
std::unordered_map<std::string, std::vector<mindspore::tensor::MSTensor *>> orig_output_map_;
std::unordered_map<std::string, mindspore::tensor::MSTensor *> orig_output_tensor_map_;
};
} // namespace session
} // namespace mindspore

@ -22,6 +22,7 @@
typedef struct BatchNormParameter {
OpParameter op_parameter_;
float epsilon_;
float momentum_;
int unit_;
int units_;
int channel_;

@ -54,22 +54,22 @@ void FusedBatchNormFp32(const void *input, const void *scale, const void *offset
}
}
void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var,
BatchNormParameter *param, float *save_mean, float *save_inv_var) {
void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param,
float *save_mean, float *save_var) {
float N = (float)param->unit_;
for (int i = 0; i < param->unit_; i++) {
for (int f = 0; f < param->channel_; f++) {
int idx = i * param->channel_ + f;
run_mean[f] += input[idx];
run_var[f] += input[idx] * input[idx];
for (int c = 0; c < param->channel_; c++) {
int idx = i * param->channel_ + c;
run_mean[c] += input[idx];
run_var[c] += input[idx] * input[idx];
}
}
const float VN = (N > 1.0f) ? (N - 1.0f) : 1.0f;
for (int f = 0; f < param->channel_; f++) {
run_mean[f] = run_mean[f] / N;
run_var[f] = run_var[f] / VN - run_mean[f] * run_mean[f];
save_mean[f] = momentum * save_mean[f] + (1 - momentum) * run_mean[f];
const float inv_var = 1.f / sqrt(run_var[f] + param->epsilon_);
save_inv_var[f] = momentum * save_inv_var[f] + (1 - momentum) * inv_var;
for (int c = 0; c < param->channel_; c++) {
run_mean[c] = run_mean[c] / N;
run_var[c] = run_var[c] / VN - run_mean[c] * run_mean[c];
save_mean[c] = param->momentum_ * save_mean[c] + (1 - param->momentum_) * run_mean[c];
const float var = run_var[c];
save_var[c] = param->momentum_ * save_var[c] + (1 - param->momentum_) * var;
}
}

@ -28,8 +28,8 @@ void BatchNormFp32(const void *input, const void *mean, const void *variance, Ba
void FusedBatchNormFp32(const void *input, const void *scale, const void *offset, const void *mean,
const void *variance, BatchNormParameter *param, int task_id, void *output);
void FusedBatchNormFp32MeanVar(const float *input, float momentum, float *run_mean, float *run_var,
BatchNormParameter *param, float *save_mean, float *save_var);
void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param,
float *save_mean, float *save_var);
#ifdef __cplusplus
}
#endif

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_
#include "nnacl/op_base.h"
// Attributes for the ApplyMomentum (momentum SGD) optimizer kernel.
typedef struct ApplyMomentumParameter {
  OpParameter op_parameter_;  // common op header, kept as the first member
  bool use_locking_;          // NOTE(review): presumably guards concurrent variable updates — confirm against schema
  bool use_nesterov_;         // apply the Nesterov-style momentum update when true
  float grad_scale_;          // scale factor applied to the incoming gradient
} ApplyMomentumParameter;
// Attributes for the SGD optimizer kernel.
typedef struct SgdParameter {
  OpParameter op_parameter_;  // common op header, kept as the first member
  float dampening_;           // dampening applied to the momentum contribution
  bool use_nesterov_;         // apply the Nesterov-style momentum update when true
  float weight_decay_;        // L2 weight-decay coefficient
} SgdParameter;
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_

@ -20,10 +20,8 @@
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
const int stride_h = conv_param->stride_h_;
const int stride_w = conv_param->stride_w_;
@ -39,10 +37,11 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
const int output_h = conv_param->output_h_;
const int output_w = conv_param->output_w_;
const int channels = conv_param->input_channel_ / conv_param->group_;
const int tot_channels = conv_param->input_channel_;
int /*channel,*/ kernel_row, kernel_col, output_rows, output_col;
int kernel_row, kernel_col, output_rows, output_col;
int row_stride_offset = 0;
@ -71,11 +70,9 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
}
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
const int stride_h = conv_param->stride_h_;
const int stride_w = conv_param->stride_w_;
@ -86,38 +83,67 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param
const int kernel_h = conv_param->kernel_h_;
const int kernel_w = conv_param->kernel_w_;
const int in_height = conv_param->input_h_;
const int in_width = conv_param->input_w_;
const int in_height = (transpose) ? conv_param->output_h_ : conv_param->input_h_;
const int in_width = (transpose) ? conv_param->output_w_ : conv_param->input_w_;
const int output_h = conv_param->output_h_;
const int output_w = conv_param->output_w_;
const int channels = conv_param->input_channel_ / conv_param->group_;
const int tot_channels = conv_param->input_channel_;
const int output_h = (transpose) ? conv_param->input_h_ : conv_param->output_h_;
const int output_w = (transpose) ? conv_param->input_w_ : conv_param->output_w_;
const int tot_channels = (transpose) ? conv_param->output_channel_ : conv_param->input_channel_;
const int channels = tot_channels / conv_param->group_;
int channel, kernel_row, kernel_col, output_rows, output_col;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
for (channel = 0; channel < channels; channel++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
if (transpose) {
for (channel = 0; channel < channels; channel++) {
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
*(data_row++) = 0;
}
input_col += stride_w;
}
}
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
input_row += stride_h;
}
}
}
}
} else {
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
for (channel = 0; channel < channels; channel++) {
int input_row = -pad_up + kernel_row * dilation_h;
for (output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, in_height)) {
for (output_col = output_w; output_col; output_col--) {
*(data_row++) = 0;
}
input_col += stride_w;
} else {
int input_col = -pad_left + kernel_col * dilation_w;
for (output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels + channel;
*(data_row++) = in_data[offset];
} else {
*(data_row++) = 0;
}
input_col += stride_w;
}
}
input_row += stride_h;
}
input_row += stride_h;
}
}
}
@ -125,10 +151,8 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param
}
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
const int stride_h = conv_param->stride_h_;
const int stride_w = conv_param->stride_w_;

@ -23,7 +23,7 @@
extern "C" {
#endif
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param);
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param);
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose);
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param);
#ifdef __cplusplus
}

@ -17,7 +17,7 @@
#include <float.h>
#include "nnacl/fp32_grad/pooling_grad.h"
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param) {
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
@ -41,7 +41,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
for (uint16_t yh = 0; yh < output_h; yh++) {
for (uint16_t yw = 0; yw < output_w; yw++) {
for (uint16_t ic = 0; ic < channel; ic++) {
int idx = (yw + yh * output_w) * channel + ic; // (ic*in_h*in_w) + (in_w*yh) + yw;
int idx = (yw + yh * output_w) * channel + ic;
float delta = inPtr[idx] / kk;
for (int32_t kh = 0; kh < win_h; kh++) {
int xh = yh * stride_h + kh - pad_h;
@ -63,7 +63,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
}
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param) {
PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;

@ -22,9 +22,9 @@
#ifdef __cplusplus
extern "C" {
#endif
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param);
void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id);
void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy_ptr, float *output_ptr,
PoolingParameter *pooling_param);
PoolingParameter *pooling_param, int task_id);
#ifdef __cplusplus
}
#endif

@ -207,6 +207,7 @@ union PrimitiveType {
LshProjection,
HashtableLookup,
SkipGram,
DeConv2DGradFilter,
CustomPredict,
CustomNormalize,
CustomExtractFeatures,
@ -215,6 +216,7 @@ union PrimitiveType {
Rfft,
FftReal,
FftImag,
Sgd,
}
enum QuantType: int {

@ -407,6 +407,27 @@ table DeConv2D {
hasBias: bool = false;
activationType: ActivationType = 0;
}
table DeConv2DGradFilter {
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
kernelW: int;
kernelH: int;
strideW: int;
strideH: int;
padMode: PadMode;
padUp: int;
padDown: int;
padLeft: int;
padRight: int;
dilateW: int;
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}
table BNGrad {
eps : float;
momentum: float;
@ -884,6 +905,11 @@ table ApplyMomentum {
useNesterov: bool;
}
table Sgd {
weightDecay: float;
dampening: float;
useNesterov: bool;
}
table Where{
condition: [bool];

@ -45,7 +45,7 @@ int CompareRelativeOutput(float *output_data, std::string file_path) {
return 1;
}
size_t output_num = output_size / sizeof(float);
int error = CompareOutputRelativeData(output_data, ground_truth, output_num);
float error = CompareOutputRelativeData(output_data, ground_truth, output_num);
delete[] ground_truth;
if (error > 1e-4) {
return 1;

@ -18,6 +18,22 @@
#include <algorithm>
namespace mindspore::kernel {
void *LiteKernel::workspace_ = nullptr;
// Allocates the workspace buffer shared by all kernels (static member).
// A size of 0 means no workspace is required. Frees any previously allocated
// buffer first so that repeated calls do not leak the old allocation.
void LiteKernel::AllocWorkspace(size_t size) {
  if (size == 0) return;
  // Releasing the old buffer here fixes a leak when AllocWorkspace is called
  // more than once without an intervening FreeWorkspace (free(nullptr) is a no-op).
  free(workspace_);
  workspace_ = malloc(size);
  if (workspace_ == nullptr) {
    MS_LOG(ERROR) << "fail to alloc " << size;
  }
}
// Releases the shared workspace buffer. Safe when no buffer was allocated
// (free(nullptr) is a no-op); the pointer is reset to avoid dangling reuse.
void LiteKernel::FreeWorkspace() {
  free(workspace_);
  workspace_ = nullptr;
}
void LiteKernel::InitOutTensorRefCount() {
for (auto *tensor : this->out_tensors_) {
tensor->SetRefCount(this->out_kernels_.size());

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_LITE_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/common/utils.h"
#ifdef ENABLE_ARM
@ -145,6 +146,11 @@ class LiteKernel {
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
void SetWorkspaceSize(size_t value) { workspace_size_ = value; }
size_t GetWorkspaceSize() { return workspace_size_; }
static void AllocWorkspace(size_t size);
static void FreeWorkspace();
void *GetWorkspace() { return workspace_; }
protected:
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()) && true; }
@ -161,6 +167,8 @@ class LiteKernel {
std::vector<LiteKernel *> out_kernels_;
bool train_mode_ = false;
bool is_model_output_ = false;
size_t workspace_size_ = 0;
static void *workspace_;
};
class SubGraphKernel : public LiteKernel {

@ -17,6 +17,10 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors for the mutable (PRIMITIVE_WRITEABLE) representation:
// values are read from the flatbuffer object-API table via AsApplyMomentum().
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; }
int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
@ -36,6 +40,10 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale"));
attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking"));
attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
@ -45,6 +53,10 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
return RET_OK;
}
#else
// Read-only attribute accessors for the serialized flatbuffer representation
// (non-PRIMITIVE_WRITEABLE build): values come from value_as_ApplyMomentum().
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); }
int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@ -53,7 +65,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb);
auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
@ -62,7 +74,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (5 != inputs.size()) {
MS_LOG(ERROR) << "ApplyMomentum should have at 5 input tensors";
MS_LOG(ERROR) << "ApplyMomentum should have at least 5 input tensors";
return RET_ERROR;
}
@ -76,6 +88,7 @@ int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<li
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_shape({1});
}
return RET_OK;

@ -39,6 +39,9 @@ class ApplyMomentum : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
float GetGradientScale() const;
bool GetUseLocking() const;
bool GetUseNesterov() const;
};
} // namespace lite
} // namespace mindspore

@ -89,6 +89,7 @@ int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> out
auto *out = outputs.front();
MS_ASSERT(in0 != nullptr);
MS_ASSERT(out != nullptr);
auto inshape = in0->shape();
int ndim = inshape.size();
for (int i = 0; i < ndim - 1; i++) {

@ -75,7 +75,7 @@ float BNGrad::GetEps() const { return this->primitive_->value_as_BNGrad()->eps()
float BNGrad::GetMomentum() const { return this->primitive_->value_as_BNGrad()->momentum(); }
#endif
int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (5 != inputs.size()) {
if (6 != inputs.size()) {
MS_LOG(ERROR) << "BNGrad should have five inputs";
return RET_ERROR;
}
@ -85,6 +85,7 @@ int BNGrad::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Ten
}
auto in = inputs[1];
auto scale = inputs[2];
outputs[0]->set_shape(in->shape());
outputs[1]->set_shape(scale->shape());
outputs[2]->set_shape(scale->shape());

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_H_
#ifndef MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_
#define MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_
#include <vector>
#include <set>
@ -44,4 +44,4 @@ class BNGrad : public PrimitiveC {
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_B_N_GRAD_INPUT_H_
#endif // MINDSPORE_LITE_SRC_OPS_BN_GRAD_H_

@ -73,5 +73,20 @@ float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_Fu
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
#endif
// Infers output tensor metadata for FusedBatchNorm: each output i mirrors the
// shape/dtype/format of input i (pairs are matched positionally; extra outputs
// beyond the input count are left untouched by the loop). An optional sixth
// output is shaped as a scalar ({1}) taking dtype/format from the first input.
// NOTE(review): the roles of the individual inputs are not visible here — confirm
// the (x, scale, offset, mean, variance) ordering against the schema.
int FusedBatchNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
  for (size_t i = 0; i < inputs_.size(); i++) {
    if (outputs_.size() <= i) break;
    outputs_.at(i)->set_shape(inputs_.at(i)->shape());
    outputs_.at(i)->set_data_type(inputs_.at(i)->data_type());
    outputs_.at(i)->SetFormat(inputs_.at(i)->GetFormat());
  }
  // Guard against an empty input list before dereferencing inputs_.at(0).
  if (outputs_.size() > 5 && !inputs_.empty()) {
    outputs_.at(5)->set_data_type(inputs_.at(0)->data_type());
    outputs_.at(5)->SetFormat(inputs_.at(0)->GetFormat());
    outputs_.at(5)->set_shape({1});
  }
  // Return the named status code instead of the magic literal 0, consistent
  // with the other InferShape implementations in this codebase.
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -39,6 +39,7 @@ class FusedBatchNorm : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
float GetEpsilon() const;
float GetMomentum() const;
int GetSpatial() const;

@ -145,7 +145,15 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf
#endif
int PoolingGrad::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
if (3 != inputs_.size()) {
MS_LOG(ERROR) << "Pooling Grad Filter should have 3 inputs";
return RET_ERROR;
}
if (1 != outputs_.size()) {
MS_LOG(ERROR) << "Pooling Grad Filter should have one output";
return RET_ERROR;
}
auto input = inputs_.at(0);
MS_ASSERT(input != nullptr);
int input_h = input->shape().at(1);

@ -151,6 +151,7 @@
#include "src/ops/depend.h"
#include "src/ops/flatten_grad.h"
#include "src/ops/log_grad.h"
#include "src/ops/sgd.h"
#endif
namespace mindspore {
@ -384,7 +385,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Dequant>(prim, inputs, quantType);
} else if (op_type == "Flatten") {
return NewPrimitiveC<Flatten>(prim, inputs, quantType);
} else if (op_type == "FusedBatchNorm") {
} else if ((op_type == "FusedBatchNorm") || (op_type == "FusedBatchNormEx")) {
return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType);
} else if (op_type == "make_tuple") {
return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
@ -452,7 +453,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Conv2DGradFilter>(prim, inputs, quantType);
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<Conv2DGradInput>(prim, inputs, quantType);
} else if (op_type == "BatchNormGrad") {
} else if ((op_type == "BatchNormGrad") || (op_type == "FusedBatchNormGradEx")) {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "FlattenGrad") {
return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
@ -460,6 +461,10 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
} else if (op_type == "SGD") {
return NewPrimitiveC<Sgd>(prim, inputs, quantType);
#else
} else if (op_type == "Conv2DBackpropInput") {
return NewPrimitiveC<DeConv2D>(prim, inputs, quantType);
@ -731,6 +736,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new NegGrad(primitive);
case schema::PrimitiveType_LogGrad:
return new LogGrad(primitive);
case schema::PrimitiveType_Sgd:
return new Sgd(primitive);
#endif
default:
@ -995,6 +1002,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<NegGrad>(primitive);
case schema::PrimitiveType_LogGrad:
return NewPrimitiveC<LogGrad>(primitive);
case schema::PrimitiveType_Sgd:
return NewPrimitiveC<Sgd>(primitive);
#endif
default:
MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type);

@ -0,0 +1,97 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/sgd.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors for the mutable (PRIMITIVE_WRITEABLE) representation:
// values are read from the flatbuffer object-API table via AsSgd().
float Sgd::GetWeightDecay() const { return this->primitive_->value.AsSgd()->weightDecay; }
float Sgd::GetDampening() const { return this->primitive_->value.AsSgd()->dampening; }
bool Sgd::GetUseNesterov() const { return this->primitive_->value.AsSgd()->useNesterov; }
// Populates this primitive's SgdT attribute table from an ANF Primitive.
// Lazily creates the underlying PrimitiveT and the attribute struct; if the
// attributes were already unpacked (value != nullptr), they are left as-is.
// Returns RET_OK on success, RET_ERROR on allocation failure or type mismatch.
int Sgd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Sgd;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Sgd) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = std::make_unique<schema::SgdT>();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    attr->weightDecay = GetValue<float>(prim.GetAttr("weight_decay"));
    attr->dampening = GetValue<float>(prim.GetAttr("dampening"));
    // Note: the source attribute name is "nesterov", not "use_nesterov".
    attr->useNesterov = GetValue<bool>(prim.GetAttr("nesterov"));
    this->primitive_->value.value = attr.release();
    if (this->primitive_->value.value == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
#else
// Read-only attribute accessors for the serialized flatbuffer representation
// (non-PRIMITIVE_WRITEABLE build): values come from value_as_Sgd().
float Sgd::GetWeightDecay() const { return this->primitive_->value_as_Sgd()->weightDecay(); }
float Sgd::GetDampening() const { return this->primitive_->value_as_Sgd()->dampening(); }
bool Sgd::GetUseNesterov() const { return this->primitive_->value_as_Sgd()->useNesterov(); }
// Re-serializes an Sgd primitive into the given FlatBufferBuilder: copies the
// weightDecay/dampening/useNesterov attributes into a fresh Sgd table and
// finishes the buffer with a Primitive wrapper of type PrimitiveType_Sgd.
// Returns RET_OK on success, RET_ERROR if the source primitive has no Sgd value.
int Sgd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_Sgd();
  if (attr == nullptr) {
    MS_LOG(ERROR) << "value_as_Sgd return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateSgd(*fbb, attr->weightDecay(), attr->dampening(), attr->useNesterov());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Sgd, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
#endif
// Validates SGD's input tensors and shapes the optional output as a scalar.
// The size checks imply inputs 0, 1 and 3 are element-wise compatible buffers
// and inputs 2 and 4 are scalars; NOTE(review): the exact tensor roles
// (weight/gradient/lr/accumulator/momentum) are not visible here — confirm
// against the kernel implementation.
int Sgd::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
  if (6 != inputs.size()) {
    // The check requires exactly 6 inputs, so the message says so (it
    // previously claimed "at least 6", which did not match the condition).
    MS_LOG(ERROR) << "Sgd should have 6 input tensors";
    return RET_ERROR;
  }
  if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[3]->ElementsNum() ||
      inputs[2]->ElementsNum() != 1 || inputs[4]->ElementsNum() != 1) {
    MS_LOG(ERROR) << "error input data size!";
    return RET_ERROR;
  }
  if (!outputs.empty()) {
    auto *out = outputs.front();
    MS_ASSERT(out != nullptr);
    out->set_data_type(inputs[0]->data_type());
    out->SetFormat(inputs[0]->GetFormat());
    out->set_shape({1});  // scalar status output
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff.

Loading…
Cancel
Save