!7432 Adam + Sparse softmax and bug fix

Merge pull request !7432 from yonibaehr/export
pull/7432/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 4843f6aba0

@ -117,6 +117,14 @@ int HSwish(const float *src, int length, float *dst) {
return NNACL_OK;
}
int HSigmoid(const float *src, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float relu6 = MSMIN(MSMAX(src[i] + 3, 0), 6);
dst[i] = relu6 / 6;
}
return NNACL_OK;
}
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val) {
if (max_val <= min_val) {
return NNACL_ERR;

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_ACTIVATION_H_
#define MINDSPORE_LITE_NNACL_ACTIVATION_H_
#ifndef MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
#define MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_
#include <math.h>
#include "nnacl/op_base.h"
@ -36,9 +36,10 @@ int Fp32Relu6(const float *src, int length, float *dst);
int LRelu(const float *src, int length, float *dst, float alpha);
int Sigmoid(const float *src, int length, float *dst);
int Tanh(const float *src, int length, float *dst);
int HSigmoid(const float *src, int length, float *dst);
int HSwish(const float *src, int length, float *dst);
int HardTanh(const float *src, int length, float *dst, float min_val, float max_val);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_ACTIVATION_H_
#endif // MINDSPORE_LITE_NNACL_FP32_ACTIVATION_H_

@ -31,14 +31,14 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par
int i, j, k;
for (i = tid; i < outer_size; i += thread_num) {
float *output_ptr = output + i * depth * inner_size;
for (k = 0; k < depth; k++) {
for (j = 0; j < inner_size; j++) {
for (k = 0; k < inner_size; k++) {
int index = indices[i * inner_size + k];
for (j = 0; j < depth; j++) {
*output_ptr = off_value;
int index = indices[i * inner_size + j];
if (index >= depth) {
return NNACL_ERRCODE_INDEX_OUT_OF_RANGE;
}
if (index == k) {
if (index == j) {
*output_ptr = on_value;
}
output_ptr++;

@ -15,27 +15,52 @@
*/
#include "nnacl/fp32_grad/gemm.h"
#include <string.h>
static void gemm_not_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = N % block_size;
int block_c4 = N - block_mod;
static void gemm_nn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_B, int ldb, float *mat_c,
int ldc) {
int i, j, k;
for (i = 0; i < M; ++i) {
for (k = 0; k < K; ++k) {
float a = alpha * mat_a[i * lda + k];
for (j = 0; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_B[k * ldb + j];
for (j = 0; j < block_c4; j += block_size) {
float *b = &mat_b[k * ldb + j];
float *c = &mat_c[i * ldc + j];
c[0] += a * b[0];
c[1] += a * b[1];
c[2] += a * b[2];
c[3] += a * b[3];
}
for (; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_b[k * ldb + j];
}
}
}
}
static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_not_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = K % block_size;
int block_c4 = K - block_mod;
int i, j, k;
for (i = 0; i < M; ++i) {
for (j = 0; j < N; ++j) {
float sum = 0;
for (k = 0; k < K; ++k) {
for (k = 0; k < block_c4; k += block_size) {
float *a = &mat_a[i * lda + k];
float *b = &mat_b[j * ldb + k];
sum += alpha * a[0] * b[0];
sum += alpha * a[1] * b[1];
sum += alpha * a[2] * b[2];
sum += alpha * a[3] * b[3];
}
for (; k < K; ++k) {
sum += alpha * mat_a[i * lda + k] * mat_b[j * ldb + k];
}
mat_c[i * ldc + j] += sum;
@ -43,23 +68,85 @@ static void gemm_nt(int M, int N, int K, float alpha, float *mat_a, int lda, flo
}
}
static void gemm_tn(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_trana_not_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
const int block_size = 4;
int block_mod = N % block_size;
int block_c4 = N - block_mod;
int i, j, k;
for (i = 0; i < M; ++i) {
for (k = 0; k < K; ++k) {
float a = alpha * mat_a[k * lda + i];
for (j = 0; j < N; ++j) {
for (j = 0; j < block_c4; j += block_size) {
float *b = &mat_b[k * ldb + j];
float *c = &mat_c[i * ldc + j];
c[0] += a * b[0];
c[1] += a * b[1];
c[2] += a * b[2];
c[3] += a * b[3];
}
for (; j < N; ++j) {
mat_c[i * ldc + j] += a * mat_b[k * ldb + j];
}
}
}
}
static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb, float *mat_c,
int ldc) {
static void gemm_trana_tranb(int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b, int ldb,
float *mat_c, int ldc) {
int i, j, k;
for (i = 0; i < M; ++i) {
const int block_size = 4;
int k_block_mod = K % block_size;
int k_block_c4 = K - k_block_mod;
int m_block_mod = M % block_size;
int m_block_c4 = M - m_block_mod;
for (i = 0; i < m_block_c4; i += block_size) {
for (j = 0; j < N; ++j) {
float sum0 = 0;
float sum1 = 0;
float sum2 = 0;
float sum3 = 0;
for (k = 0; k < k_block_c4; k += block_size) {
float *b = &mat_b[j * ldb + k];
sum0 += alpha * mat_a[i + k * lda] * b[0];
sum0 += alpha * mat_a[i + (k + 1) * lda] * b[1];
sum0 += alpha * mat_a[i + (k + 2) * lda] * b[2];
sum0 += alpha * mat_a[i + (k + 3) * lda] * b[3];
sum1 += alpha * mat_a[i + 1 + k * lda] * b[0];
sum1 += alpha * mat_a[i + 1 + (k + 1) * lda] * b[1];
sum1 += alpha * mat_a[i + 1 + (k + 2) * lda] * b[2];
sum1 += alpha * mat_a[i + 1 + (k + 3) * lda] * b[3];
sum2 += alpha * mat_a[i + 2 + k * lda] * b[0];
sum2 += alpha * mat_a[i + 2 + (k + 1) * lda] * b[1];
sum2 += alpha * mat_a[i + 2 + (k + 2) * lda] * b[2];
sum2 += alpha * mat_a[i + 2 + (k + 3) * lda] * b[3];
sum3 += alpha * mat_a[i + 3 + k * lda] * b[0];
sum3 += alpha * mat_a[i + 3 + (k + 1) * lda] * b[1];
sum3 += alpha * mat_a[i + 3 + (k + 2) * lda] * b[2];
sum3 += alpha * mat_a[i + 3 + (k + 3) * lda] * b[3];
}
for (; k < K; ++k) {
float *b = &mat_b[j * ldb + k];
sum0 += alpha * mat_a[i + (k * lda)] * b[0];
sum1 += alpha * mat_a[i + 1 + (k * lda)] * b[0];
sum2 += alpha * mat_a[i + 2 + (k * lda)] * b[0];
sum3 += alpha * mat_a[i + 3 + (k * lda)] * b[0];
}
mat_c[i * ldc + j] += sum0;
mat_c[(i + 1) * ldc + j] += sum1;
mat_c[(i + 2) * ldc + j] += sum2;
mat_c[(i + 3) * ldc + j] += sum3;
}
}
// no more block of 4x4
for (; i < M; ++i) {
for (j = 0; j < N; ++j) {
float sum = 0;
for (k = 0; k < K; ++k) {
@ -74,34 +161,37 @@ static void gemm_tt(int M, int N, int K, float alpha, float *mat_a, int lda, flo
// M - number of rows of matrix a
// N - number of cols of matrix b
// K - number of cols of matrix a
// lda - fast dim of matrix a
// ldb - fast dim of matrix b
// ldc - fast dim of matrix c
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
int ldb, float beta, float *mat_c, int ldc) {
if (beta >= 0.f && beta <= 0.f) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
mat_c[i * ldc + j] = 0;
}
}
memset(mat_c, 0, M * N * sizeof(float));
} else if (beta < 1.f || beta > 1.f) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
mat_c[i * ldc + j] *= beta;
}
const int block_size = 4;
const int size = M * N;
int block_mod = size % block_size;
int block_c4 = size - block_mod;
int i;
for (i = 0; i < block_c4; i += block_size) {
float *c = &mat_c[i];
c[0] *= beta;
c[1] *= beta;
c[2] *= beta;
c[3] *= beta;
}
}
int t;
for (t = 0; t < M; ++t) {
if (!transpose_a && !transpose_b) {
gemm_nn(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else if (transpose_a && !transpose_b) {
gemm_tn(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else if (!transpose_a && transpose_b) {
gemm_nt(1, N, K, alpha, mat_a + t * lda, lda, mat_b, ldb, mat_c + t * ldc, ldc);
} else {
gemm_tt(1, N, K, alpha, mat_a + t, lda, mat_b, ldb, mat_c + t * ldc, ldc);
for (; i < size; ++i) {
mat_c[i] *= beta;
}
}
if (transpose_a && transpose_b) {
gemm_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else if (!transpose_a && !transpose_b) {
gemm_not_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else if (!transpose_a && transpose_b) {
gemm_not_trana_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
} else {
gemm_trana_not_tranb(M, N, K, alpha, mat_a, lda, mat_b, ldb, mat_c, ldc);
}
}
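
For reference (not part of this change), here is a minimal usage sketch of the gemm entry point documented above. It assumes row-major, contiguous matrices whose leading dimensions equal the row widths, and that the function is declared in "nnacl/fp32_grad/gemm.h" as shown in this diff:

#include <stdio.h>
#include "nnacl/fp32_grad/gemm.h"  /* assumed include path, as used above */

int main(void) {
  /* C(2x2) = 1.0f * A(2x3) * B(3x2) + 0.0f * C, no transposes */
  float a[] = {1, 2, 3,
               4, 5, 6};   /* M = 2, K = 3, lda = 3 */
  float b[] = {7, 8,
               9, 10,
               11, 12};    /* K = 3, N = 2, ldb = 2 */
  float c[4] = {0};        /* M x N output, ldc = 2 */
  gemm(0, 0, 2, 2, 3, 1.0f, a, 3, b, 2, 0.0f, c, 2);
  printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  /* expected: 58 64 / 139 154 */
  return 0;
}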

@ -21,7 +21,6 @@
typedef struct ApplyMomentumParameter {
OpParameter op_parameter_;
bool use_locking_;
bool use_nesterov_;
float grad_scale_;
} ApplyMomentumParameter;
@ -33,4 +32,9 @@ typedef struct SgdParameter {
float weight_decay_;
} SgdParameter;
typedef struct AdamParameter {
OpParameter op_parameter_;
bool use_nesterov_;
} AdamParameter;
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_OPTIMIZER_H_

@ -182,7 +182,7 @@ union PrimitiveType {
Conv2DGradInput,
PoolingGrad,
BNGrad,
BNGradInput,
Assign,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
@ -217,6 +217,8 @@ union PrimitiveType {
FftReal,
FftImag,
Sgd,
Adam,
GroupConv2DGradInput,
}
enum QuantType: int {

@ -224,7 +224,29 @@ table Conv2DGradInput {
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}table FusedBatchNorm {
}
table GroupConv2DGradInput {
format: Format = 0;
group: int;
channelIn: int;
channelOut: int;
kernelW: int;
kernelH: int;
strideW: int;
strideH: int;
padMode: PadMode;
padUp: int;
padDown: int;
padLeft: int;
padRight: int;
dilateW: int;
dilateH: int;
hasBias: bool = false;
activationType: ActivationType = 0;
}
table FusedBatchNorm {
epsilon: float = 0.00001; // eg. epsilon=0.001
momentum: float = 0.9;
spatial: int = 1;
@ -901,7 +923,6 @@ table TupleGetItem {
table ApplyMomentum {
gradientScale: float;
useLocking: bool;
useNesterov: bool;
}
@ -911,6 +932,14 @@ table Sgd {
useNesterov: bool;
}
table Adam {
useNesterov: bool;
}
table Assign {
}
table Where{
condition: [bool];
}

@ -50,6 +50,10 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "HSwish") {
attr->type = schema::ActivationType_HSWISH;
} else if (prim.name() == "HSigmoid") {
attr->type = schema::ActivationType_HSIGMOID;
}
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {

@ -43,8 +43,12 @@ int ActivationGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodeP
attr->type = schema::ActivationType_RELU;
} else if (prim.name() == "SigmoidGrad") {
attr->type = schema::ActivationType_SIGMOID;
} else if (prim.name() == "Relu6Grad") {
} else if (prim.name() == "ReLU6Grad") {
attr->type = schema::ActivationType_RELU6;
} else if (prim.name() == "HSigmoidGrad") {
attr->type = schema::ActivationType_HSIGMOID;
} else if (prim.name() == "HSwishGrad") {
attr->type = schema::ActivationType_HSWISH;
}
attr->alpha = 0; // alpha;
this->primitive_->value.value = attr.release();

@ -0,0 +1,91 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/adam.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool Adam::GetUseNesterov() const { return this->primitive_->value.AsAdam()->useNesterov; }
int Adam::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Adam;
}
if (this->primitive_->value.type != schema::PrimitiveType_Adam) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = std::make_unique<schema::AdamT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
bool Adam::GetUseNesterov() const { return this->primitive_->value_as_Adam()->useNesterov(); }
int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Adam();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Adam return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAdam(*fbb, attr->useNesterov());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Adam, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (10 != inputs.size()) {
MS_LOG(ERROR) << "Adam should have at least 8 input tensors";
return RET_ERROR;
}
if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum() || inputs[0]->ElementsNum() != inputs[2]->ElementsNum() ||
inputs[0]->ElementsNum() != inputs[9]->ElementsNum() || inputs[3]->ElementsNum() != 1 ||
inputs[4]->ElementsNum() != 1 || inputs[5]->ElementsNum() != 1 || inputs[6]->ElementsNum() != 1 ||
inputs[7]->ElementsNum() != 1 || inputs[8]->ElementsNum() != 1) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}
if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_shape({1});
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,47 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_ADAM_H_
#define MINDSPORE_LITE_SRC_OPS_ADAM_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Adam : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Adam, PrimitiveC);
Adam() = default;
explicit Adam(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Adam() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool GetUseNesterov() const;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_ADAM_H_

@ -82,8 +82,11 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
if (!GetInferFlag()) {
return RET_OK;
}
output->set_shape(input->shape());
// make sure all elements have the same size or 1 (broadcasting) in all dimensions
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape() != inputs.at(0)->shape()) {
if (inputs.at(i)->shape().size() != inputs.at(0)->shape().size()) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
@ -93,7 +96,22 @@ int AddN::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs
}
}
output->set_shape(input->shape());
for (size_t d = 0; d < input->shape().size(); ++d) {
int max_dim = input->shape().at(d);
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs.at(i)->shape().at(d) > max_dim) {
max_dim = inputs.at(i)->shape().at(d);
}
}
for (size_t i = 0; i < inputs.size(); ++i) {
if ((inputs.at(i)->shape().at(d) != max_dim) && (inputs.at(i)->shape().at(d) != 1)) {
MS_LOG(ERROR) << "AddN inputs shape is not equal!";
return RET_INPUT_TENSOR_ERROR;
}
}
output->shape()[d] = max_dim; // set the biggest dimension in the output tensor
}
return RET_OK;
}
} // namespace lite
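
As an aside (not part of the diff), the broadcasting rule that the new InferShape loops enforce (every input dimension must equal the per-dimension maximum or be 1) can be sketched standalone as below; the helper name BroadcastShape and the fixed rank of 3 are illustrative assumptions, not the lite::Tensor API:

#include <stdio.h>

/* Returns 0 and fills out_shape on success, -1 if the shapes cannot broadcast. */
static int BroadcastShape(const int shapes[][3], int num_inputs, int rank, int out_shape[]) {
  for (int d = 0; d < rank; ++d) {
    int max_dim = shapes[0][d];
    for (int i = 1; i < num_inputs; ++i) {
      if (shapes[i][d] > max_dim) max_dim = shapes[i][d];
    }
    for (int i = 0; i < num_inputs; ++i) {
      if (shapes[i][d] != max_dim && shapes[i][d] != 1) return -1;  /* incompatible */
    }
    out_shape[d] = max_dim;  /* the output takes the biggest dimension */
  }
  return 0;
}

int main(void) {
  const int shapes[3][3] = {{1, 3, 4}, {2, 3, 1}, {2, 1, 4}};
  int out[3];
  if (BroadcastShape(shapes, 3, 3, out) == 0) {
    printf("%d %d %d\n", out[0], out[1], out[2]);  /* prints: 2 3 4 */
  }
  return 0;
}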

@ -18,7 +18,6 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value.AsApplyMomentum()->gradientScale; }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value.AsApplyMomentum()->useLocking; }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value.AsApplyMomentum()->useNesterov; }
int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
@ -41,7 +40,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
return RET_ERROR;
}
attr->gradientScale = GetValue<float>(prim.GetAttr("gradient_scale"));
attr->useLocking = GetValue<bool>(prim.GetAttr("use_locking"));
attr->useNesterov = GetValue<bool>(prim.GetAttr("use_nesterov"));
this->primitive_->value.value = attr.release();
@ -54,7 +52,6 @@ int ApplyMomentum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePt
}
#else
float ApplyMomentum::GetGradientScale() const { return this->primitive_->value_as_ApplyMomentum()->gradientScale(); }
bool ApplyMomentum::GetUseLocking() const { return this->primitive_->value_as_ApplyMomentum()->useLocking(); }
bool ApplyMomentum::GetUseNesterov() const { return this->primitive_->value_as_ApplyMomentum()->useNesterov(); }
int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@ -65,7 +62,7 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
MS_LOG(ERROR) << "value_as_ApplyMomentum return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useLocking(), attr->useNesterov());
auto val_offset = schema::CreateApplyMomentum(*fbb, attr->gradientScale(), attr->useNesterov());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ApplyMomentum, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;

@ -40,7 +40,6 @@ class ApplyMomentum : public PrimitiveC {
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
float GetGradientScale() const;
bool GetUseLocking() const;
bool GetUseNesterov() const;
};
} // namespace lite

@ -0,0 +1,82 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/assign.h"
#include <memory>
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Assign::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Assign;
}
if (this->primitive_->value.type != schema::PrimitiveType_Assign) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::AssignT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Assign();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Assign return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAssign(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Assign, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (2 != inputs.size()) {
MS_LOG(ERROR) << "Assign should have at least 5 input tensors";
return RET_ERROR;
}
if (inputs[0]->ElementsNum() != inputs[1]->ElementsNum()) {
MS_LOG(ERROR) << "error input data size!";
return RET_ERROR;
}
if (!outputs.empty()) {
auto *out = outputs.front();
MS_ASSERT(out != nullptr);
out->set_data_type(inputs[0]->data_type());
out->SetFormat(inputs[0]->GetFormat());
out->set_shape({1});
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,43 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_ASSIGN_H_
#define MINDSPORE_LITE_SRC_OPS_ASSIGN_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Assign : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Assign, PrimitiveC);
Assign() = default;
explicit Assign(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Assign() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_ASSIGN_H_

@ -45,8 +45,8 @@ int BNGrad::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
attr->momentum = GetValue<float>(prim.GetAttr("momentum"));
// FusedBatchNormGrad does not get this attribute
if (prim.GetAttr("eps") != nullptr) {
attr->eps = GetValue<float>(prim.GetAttr("eps"));
if (prim.GetAttr("epsilon") != nullptr) {
attr->eps = GetValue<float>(prim.GetAttr("epsilon"));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {

@ -15,7 +15,7 @@
*/
#include "src/ops/conv2d_grad_input.h"
#include "src/ops/group_conv2d_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -86,6 +86,9 @@ int Conv2DGradInput::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
return RET_ERROR;
}
attr->group = GetValue<int>(prim.GetAttr("group"));
if (attr->group > 1) {
this->primitive_->value.type = schema::PrimitiveType_GroupConv2DGradInput;
}
auto format = GetValue<std::string>(prim.GetAttr("data_format"));
if (format == "NCHW") {
attr->format = schema::Format_NCHW;

@ -26,6 +26,30 @@ void Exp::SetShift(float shift) { this->primitive_->value.AsExp()->shift = shift
float Exp::GetBase() const { return this->primitive_->value.AsExp()->base; }
float Exp::GetScale() const { return this->primitive_->value.AsExp()->scale; }
float Exp::GetShift() const { return this->primitive_->value.AsExp()->shift; }
int Exp::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Exp;
}
if (this->primitive_->value.type != schema::PrimitiveType_Exp) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::ExpT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {

@ -33,6 +33,7 @@ class Exp : public PrimitiveC {
void SetBase(float base);
void SetShift(float shift);
void SetScale(float scale);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Exp() = default;

@ -0,0 +1,172 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/group_conv2d_grad_input.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value.AsGroupConv2DGradInput()->format; }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value.AsGroupConv2DGradInput()->group; }
int GroupConv2DGradInput::GetChannelIn() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelIn; }
int GroupConv2DGradInput::GetChannelOut() const { return this->primitive_->value.AsGroupConv2DGradInput()->channelOut; }
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelW; }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value.AsGroupConv2DGradInput()->kernelH; }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideW; }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value.AsGroupConv2DGradInput()->strideH; }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value.AsGroupConv2DGradInput()->padMode; }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value.AsGroupConv2DGradInput()->padUp; }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value.AsGroupConv2DGradInput()->padDown; }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value.AsGroupConv2DGradInput()->padLeft; }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value.AsGroupConv2DGradInput()->padRight; }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateW; }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value.AsGroupConv2DGradInput()->dilateH; }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value.AsGroupConv2DGradInput()->hasBias; }
int GroupConv2DGradInput::GetActivationType() const {
return this->primitive_->value.AsGroupConv2DGradInput()->activationType;
}
void GroupConv2DGradInput::SetFormat(int format) {
this->primitive_->value.AsGroupConv2DGradInput()->format = (schema::Format)format;
}
void GroupConv2DGradInput::SetGroup(int group) { this->primitive_->value.AsGroupConv2DGradInput()->group = group; }
void GroupConv2DGradInput::SetChannelIn(int channel_in) {
this->primitive_->value.AsGroupConv2DGradInput()->channelIn = channel_in;
}
void GroupConv2DGradInput::SetChannelOut(int channel_out) {
this->primitive_->value.AsGroupConv2DGradInput()->channelOut = channel_out;
}
void GroupConv2DGradInput::SetKernelW(int kernel_w) {
this->primitive_->value.AsGroupConv2DGradInput()->kernelW = kernel_w;
}
void GroupConv2DGradInput::SetKernelH(int kernel_h) {
this->primitive_->value.AsGroupConv2DGradInput()->kernelH = kernel_h;
}
void GroupConv2DGradInput::SetStrideW(int stride_w) {
this->primitive_->value.AsGroupConv2DGradInput()->strideW = stride_w;
}
void GroupConv2DGradInput::SetStrideH(int stride_h) {
this->primitive_->value.AsGroupConv2DGradInput()->strideH = stride_h;
}
void GroupConv2DGradInput::SetPadMode(int pad_mode) {
this->primitive_->value.AsGroupConv2DGradInput()->padMode = (schema::PadMode)pad_mode;
}
void GroupConv2DGradInput::SetPadUp(int pad_up) { this->primitive_->value.AsGroupConv2DGradInput()->padUp = pad_up; }
void GroupConv2DGradInput::SetPadDown(int pad_down) {
this->primitive_->value.AsGroupConv2DGradInput()->padDown = pad_down;
}
void GroupConv2DGradInput::SetPadLeft(int pad_left) {
this->primitive_->value.AsGroupConv2DGradInput()->padLeft = pad_left;
}
void GroupConv2DGradInput::SetPadRight(int pad_right) {
this->primitive_->value.AsGroupConv2DGradInput()->padRight = pad_right;
}
void GroupConv2DGradInput::SetDilateW(int dilate_w) {
this->primitive_->value.AsGroupConv2DGradInput()->dilateW = dilate_w;
}
void GroupConv2DGradInput::SetDilateH(int dilate_h) {
this->primitive_->value.AsGroupConv2DGradInput()->dilateH = dilate_h;
}
void GroupConv2DGradInput::SetHasBias(bool has_bias) {
this->primitive_->value.AsGroupConv2DGradInput()->hasBias = has_bias;
}
void GroupConv2DGradInput::SetActivationType(int activation_type) {
this->primitive_->value.AsGroupConv2DGradInput()->activationType = (schema::ActivationType)activation_type;
}
#else
int GroupConv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_GroupConv2DGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_GroupConv2DGradInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGroupConv2DGradInput(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GroupConv2DGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int GroupConv2DGradInput::GetFormat() const { return this->primitive_->value_as_GroupConv2DGradInput()->format(); }
int GroupConv2DGradInput::GetGroup() const { return this->primitive_->value_as_GroupConv2DGradInput()->group(); }
int GroupConv2DGradInput::GetChannelIn() const {
return this->primitive_->value_as_GroupConv2DGradInput()->channelIn();
}
int GroupConv2DGradInput::GetChannelOut() const {
return this->primitive_->value_as_GroupConv2DGradInput()->channelOut();
}
int GroupConv2DGradInput::GetKernelW() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelW(); }
int GroupConv2DGradInput::GetKernelH() const { return this->primitive_->value_as_GroupConv2DGradInput()->kernelH(); }
int GroupConv2DGradInput::GetStrideW() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideW(); }
int GroupConv2DGradInput::GetStrideH() const { return this->primitive_->value_as_GroupConv2DGradInput()->strideH(); }
int GroupConv2DGradInput::GetPadMode() const { return this->primitive_->value_as_GroupConv2DGradInput()->padMode(); }
int GroupConv2DGradInput::GetPadUp() const { return this->primitive_->value_as_GroupConv2DGradInput()->padUp(); }
int GroupConv2DGradInput::GetPadDown() const { return this->primitive_->value_as_GroupConv2DGradInput()->padDown(); }
int GroupConv2DGradInput::GetPadLeft() const { return this->primitive_->value_as_GroupConv2DGradInput()->padLeft(); }
int GroupConv2DGradInput::GetPadRight() const { return this->primitive_->value_as_GroupConv2DGradInput()->padRight(); }
int GroupConv2DGradInput::GetDilateW() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateW(); }
int GroupConv2DGradInput::GetDilateH() const { return this->primitive_->value_as_GroupConv2DGradInput()->dilateH(); }
bool GroupConv2DGradInput::GetHasBias() const { return this->primitive_->value_as_GroupConv2DGradInput()->hasBias(); }
int GroupConv2DGradInput::GetActivationType() const {
return this->primitive_->value_as_GroupConv2DGradInput()->activationType();
}
#endif
int GroupConv2DGradInput::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (3 != inputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad Input should have 3 inputs";
return RET_ERROR;
}
if (1 != outputs.size()) {
MS_LOG(ERROR) << "Conv2d Grad input should have one output";
return RET_ERROR;
}
auto *in0 = inputs.at(0);
auto *in = inputs.at(2);
MS_ASSERT(in != nullptr);
std::vector<int> output_shape;
int *out_shape = reinterpret_cast<int *>(in->MutableData());
int new_size = in->ElementsNum();
if (in0->GetFormat() == in->GetFormat()) {
for (int i = 0; i < new_size; i++) output_shape.push_back(out_shape[i]);
} else {
if ((in0->GetFormat() == schema::Format_NHWC) && (in->GetFormat() == schema::Format_NCHW)) {
output_shape.push_back(out_shape[0]);
output_shape.push_back(out_shape[2]);
output_shape.push_back(out_shape[3]);
output_shape.push_back(out_shape[1]);
} else {
MS_LOG(ERROR) << "Shape covnert is not supported";
return RET_ERROR;
}
}
auto *out = outputs.at(0);
MS_ASSERT(out != nullptr);
out->set_shape(output_shape);
out->set_data_type(in0->data_type());
out->SetFormat(in0->GetFormat());
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,79 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_
#define MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_
#include <vector>
#include <set>
#include <cmath>
#include <memory>
#include <string>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class GroupConv2DGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GroupConv2DGradInput, PrimitiveC);
GroupConv2DGradInput() = default;
explicit GroupConv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
GroupConv2DGradInput() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
int GetFormat() const;
int GetGroup() const;
int GetChannelIn() const;
int GetChannelOut() const;
int GetKernelW() const;
int GetKernelH() const;
int GetStrideW() const;
int GetStrideH() const;
int GetPadMode() const;
int GetPadUp() const;
int GetPadDown() const;
int GetPadLeft() const;
int GetPadRight() const;
int GetDilateW() const;
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_OPS_GROUP_CONV2D_GRAD_INPUT_H_

@ -18,7 +18,31 @@
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
#ifdef PRIMITIVE_WRITEABLE
int Neg::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Neg;
}
if (this->primitive_->value.type != schema::PrimitiveType_Neg) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::NegT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(primitive != nullptr);
MS_ASSERT(fbb != nullptr);

@ -31,6 +31,7 @@ class Neg : public ArithmeticSelf {
MS_DECLARE_PARENT(Neg, ArithmeticSelf);
Neg() = default;
explicit Neg(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Neg() = default;

@ -23,6 +23,37 @@ int OneHot::GetAxis() const { return this->primitive_->value.AsOneHot()->axis; }
void OneHot::SetAxis(int axis) { this->primitive_->value.AsOneHot()->axis = axis; }
int OneHot::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_OneHot;
}
if (this->primitive_->value.type != schema::PrimitiveType_OneHot) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::OneHotT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = -1;
if (prim.GetAttr("axis") != nullptr) {
attr->axis = GetValue<int>(prim.GetAttr("axis"));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int OneHot::GetAxis() const { return this->primitive_->value_as_OneHot()->axis(); }

Some files were not shown because too many files have changed in this diff.