tod networks test ci

pull/8478/head
yoni 4 years ago
parent a3066105d5
commit 0512d58135

@ -219,6 +219,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/internal)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nnacl)
if (ENABLE_TOOLS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/benchmark)
if (SUPPORT_TRAIN)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/net_train)
endif()
endif()
if (NOT WIN32)
if (ENABLE_TOOLS)

@ -18,32 +18,36 @@
#include <vector>
#include "include/model.h"
namespace mindspore::lite {
namespace mindspore {
namespace lite {
/// \brief TrainModel Defines a class that allows importing and exporting a MindSpore trainable model
struct TrainModel : public lite::Model {
/// \brief Static method to create a TrainModel pointer.
///
/// \param[in] model_buf Define the buffer read from a model file.
/// \param[in] size Define bytes number of model buffer.
/// \brief Static method to create a TrainModel object
///
/// \return Pointer of MindSpore Lite TrainModel.
/// \param[in] model_buf A buffer that was read from a MS model file
/// \param[in] size Length of the buffer
///
/// \return Pointer to MindSpore Lite TrainModel
static TrainModel *Import(const char *model_buf, size_t size);
/// \brief Free meta graph temporary buffer
/// \brief Free meta graph related data
void Free() override;
/// \brief TrainModel destruct, free all memory
/// \brief Class destructor, free all memory
virtual ~TrainModel();
/// \brief Export Model into buf.
/// \brief Export Model into a buffer
///
/// \param[in] buf Define the buffer to Export into. If nullptr, buf will be allocated
/// \param[in] len size of the buffer.
/// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated
/// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer
///
/// \return Pointer to buffer with exported model
char* ExportBuf(char* buf, size_t* len) const;
char *ExportBuf(char *buf, size_t *len) const;
size_t buf_size_;
};
} // namespace mindspore::lite
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_TRAIN_MODEL_H_

@ -25,16 +25,59 @@
namespace mindspore {
namespace session {
/// \brief TrainSession Defines a class that allows training a MindSpore model
class TrainSession : public session::LiteSession {
public:
/// \brief Class destructor
virtual ~TrainSession() = default;
/// \brief Static method to create a TrainSession object
///
/// \param[in] context Defines the context of the session to be created
///
/// \return Pointer of MindSpore Lite TrainSession
static TrainSession *CreateSession(lite::Context *context);
/// \brief Compile MindSpore Lite train model
///
/// \note CompileTrainGraph should be called before RunGraph
///
/// \param[in] model Define the model to be compiled
///
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int CompileTrainGraph(lite::TrainModel *model) = 0;
/// \brief Export the trained model into a buffer
///
/// \param[in] buf The buffer to Export into. If equal to nullptr, buf will be allocated
/// \param[in,out] len Size of the pre-allocated buffer, and returned size of the exported buffer
///
/// \return pointer to the export buffer
virtual void *ExportToBuf(char *buf, size_t *len) const = 0;
virtual void Train() = 0;
/// \brief Save the trained model into a flatbuffer file
///
/// \param[in] filename Filename to save flatbuffer to
///
/// \return 0 on success or -1 in case of error
virtual int SaveToFile(const std::string &filename) const = 0;
/// \brief Set model to train mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Train() = 0;
/// \brief Check mode of model
///
/// \return boolean indication if model is in train mode
bool IsTrain() { return train_mode_ == true; }
virtual void Eval() = 0;
/// \brief Set model to eval mode
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h
virtual int Eval() = 0;
/// \brief Check mode of model
///
/// \return boolean indication if model is in eval mode
bool IsEval() { return train_mode_ == false; }
protected:

@ -270,11 +270,13 @@ if (BUILD_MINDDATA STREQUAL "full")
${CORE_DIR}/utils/ms_utils.cc
)
find_package(Threads REQUIRED)
target_link_libraries(minddata-lite
securec
jpeg-turbo
jpeg
mindspore::json
Threads::Threads
)
# ref: https://github.com/android/ndk/issues/1202

@ -55,20 +55,30 @@ void FusedBatchNormFp32(const void *input, const void *scale, const void *offset
void FusedBatchNormFp32MeanVar(const float *input, float *run_mean, float *run_var, BatchNormParameter *param,
float *save_mean, float *save_var) {
float N = (float)param->unit_;
const float N = (float)param->unit_;
const float VN = N;
const float VNUB = (N > 1.0f) ? (N - 1.0f) : 1.0f;
const float momentum = (1.0f - param->momentum_);
for (int i = 0; i < param->unit_; i++) {
for (int c = 0; c < param->channel_; c++) {
int idx = i * param->channel_ + c;
run_mean[c] += input[idx];
run_var[c] += input[idx] * input[idx];
}
}
const float VN = (N > 1.0f) ? (N - 1.0f) : 1.0f;
for (int c = 0; c < param->channel_; c++) {
run_mean[c] = run_mean[c] / N;
run_var[c] = run_var[c] / VN - run_mean[c] * run_mean[c];
save_mean[c] = param->momentum_ * save_mean[c] + (1 - param->momentum_) * run_mean[c];
const float var = run_var[c];
save_var[c] = param->momentum_ * save_var[c] + (1 - param->momentum_) * var;
run_mean[c] /= N;
}
for (int i = 0; i < param->unit_; i++) {
for (int c = 0; c < param->channel_; c++) {
int idx = i * param->channel_ + c;
run_var[c] += (input[idx] - run_mean[c]) * (input[idx] - run_mean[c]);
}
}
for (int c = 0; c < param->channel_; c++) {
float unbiased_var = (run_var[c] / VNUB);
run_var[c] = (run_var[c] / VN);
save_mean[c] = momentum * save_mean[c] + (1.0f - momentum) * run_mean[c];
save_var[c] = momentum * save_var[c] + (1.0f - momentum) * unbiased_var;
}
}

@ -72,7 +72,7 @@ int HSwishGrad(float *src0, float *src1, int length, float *dst) {
int HSigmoidGrad(float *src0, float *src1, int length, float *dst) {
for (int i = 0; i < length; ++i) {
float tmp = (src1[i] > 3.0f ? 1.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
float tmp = (src1[i] > 3.0f ? 0.0f : (src1[i] < -3.0f ? 0.0f : 1.0f / 6.0f));
dst[i] = tmp * src0[i];
}
return NNACL_OK;

@ -15,6 +15,8 @@
*/
#include "nnacl/fp32_grad/arithmetic_grad.h"
#include <string.h>
#include "nnacl/fp32_grad/utils.h"
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size) {
for (int i = 0; i < element_size; i++) {
@ -27,3 +29,103 @@ void ElementMulAndDivNegSquare(const float *a, const float *b, const float *deno
output[i] = -a[i] * b[i] / (denom[i] * denom[i]);
}
}
// Back-propagates dy through element-wise Maximum(input0, input1), with support for broadcast inputs.
//
// input0, input1 : forward operands, laid out row-major according to input0_dims / input1_dims.
// dy             : incoming gradient, laid out row-major according to dy_dims.
// output0/1      : gradients w.r.t. input0 / input1; each holds as many elements as its input.
// num_dims       : number of entries in every *_dims array (at most 8, the capacity of the scratch arrays).
//
// Every dy element is routed to output0 where input0 > input1 and to output1 where input1 >= input0,
// so ties go to input1. An axis on which an input has extent 1 is treated as broadcast, and the
// contributions along that axis are summed into the single gradient element.
void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
                   const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) {
  enum { kMaxDims = 8 };  // capacity of the index / axes scratch arrays below
  if (num_dims > kMaxDims) {
    return;  // cannot index this many dimensions without overrunning the scratch arrays
  }

  int num_output0 = 1;
  int num_output1 = 1;
  int same_shape = 1;
  for (int idx = 0; idx < num_dims; ++idx) {
    num_output0 *= input0_dims[idx];
    num_output1 *= input1_dims[idx];
    if (input0_dims[idx] != input1_dims[idx]) {
      same_shape = 0;
    }
  }

  if (same_shape) {
    // Identical shapes: the row-major offsets of input0, input1 and dy coincide, so walk them linearly.
    for (int i = 0; i < num_output0; ++i) {
      output0[i] = input0[i] > input1[i] ? dy[i] : 0.0f;
      output1[i] = input1[i] >= input0[i] ? dy[i] : 0.0f;
    }
    return;
  }

  // Broadcast case: gradients are accumulated, so both outputs start from zero.
  memset(output0, 0, num_output0 * sizeof(float));
  memset(output1, 0, num_output1 * sizeof(float));

  int num_dy = 1;
  for (int idx = 0; idx < num_dims; ++idx) {
    num_dy *= dy_dims[idx];
  }
  if (num_dy <= 0) {
    return;  // empty gradient; the index walk below assumes every extent is at least 1
  }

  int input_iter[kMaxDims] = {0};
  // Axes with extent 1 are skipped when computing an input's offset. Sized like input_iter so that
  // a shape made entirely of unit axes cannot overflow them.
  int axes0[kMaxDims] = {0};
  int axes1[kMaxDims] = {0};
  int num_axes0 = 0;
  int num_axes1 = 0;
  for (int i = 0; i < num_dims; i++) {
    if (input0_dims[i] == 1) {
      axes0[num_axes0++] = i;
    }
    if (input1_dims[i] == 1) {
      axes1[num_axes1++] = i;
    }
  }

  do {
    size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0);
    size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1);
    // input_iter walks dy's index space, so dy must be addressed with dy's own dimensions.
    size_t yt_offset = GetInputOffset(num_dims, dy_dims, input_iter);
    output0[offset0] += input0[offset0] > input1[offset1] ? dy[yt_offset] : 0.0f;
    output1[offset1] += input1[offset1] >= input0[offset0] ? dy[yt_offset] : 0.0f;
  } while (NextIndex(num_dims, dy_dims, input_iter));
}
// Back-propagates dy through element-wise Minimum(input0, input1), with support for broadcast inputs.
//
// input0, input1 : forward operands, laid out row-major according to input0_dims / input1_dims.
// dy             : incoming gradient, laid out row-major according to dy_dims.
// output0/1      : gradients w.r.t. input0 / input1; each holds as many elements as its input.
// num_dims       : number of entries in every *_dims array (at most 8, the capacity of the scratch arrays).
//
// Every dy element is routed to output0 where input0 < input1 and to output1 where input1 <= input0,
// so ties go to input1. An axis on which an input has extent 1 is treated as broadcast, and the
// contributions along that axis are summed into the single gradient element.
void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
                   const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims) {
  enum { kMaxDims = 8 };  // capacity of the index / axes scratch arrays below
  if (num_dims > kMaxDims) {
    return;  // cannot index this many dimensions without overrunning the scratch arrays
  }

  int num_output0 = 1;
  int num_output1 = 1;
  int same_shape = 1;
  for (int idx = 0; idx < num_dims; ++idx) {
    num_output0 *= input0_dims[idx];
    num_output1 *= input1_dims[idx];
    if (input0_dims[idx] != input1_dims[idx]) {
      same_shape = 0;
    }
  }

  if (same_shape) {
    // Identical shapes: the row-major offsets of input0, input1 and dy coincide, so walk them linearly.
    for (int i = 0; i < num_output0; ++i) {
      output0[i] = input0[i] < input1[i] ? dy[i] : 0.0f;
      output1[i] = input1[i] <= input0[i] ? dy[i] : 0.0f;
    }
    return;
  }

  // Broadcast case: gradients are accumulated, so both outputs start from zero.
  memset(output0, 0, num_output0 * sizeof(float));
  memset(output1, 0, num_output1 * sizeof(float));

  int num_dy = 1;
  for (int idx = 0; idx < num_dims; ++idx) {
    num_dy *= dy_dims[idx];
  }
  if (num_dy <= 0) {
    return;  // empty gradient; the index walk below assumes every extent is at least 1
  }

  int input_iter[kMaxDims] = {0};
  // Axes with extent 1 are skipped when computing an input's offset. Sized like input_iter so that
  // a shape made entirely of unit axes cannot overflow them.
  int axes0[kMaxDims] = {0};
  int axes1[kMaxDims] = {0};
  int num_axes0 = 0;
  int num_axes1 = 0;
  for (int i = 0; i < num_dims; i++) {
    if (input0_dims[i] == 1) {
      axes0[num_axes0++] = i;
    }
    if (input1_dims[i] == 1) {
      axes1[num_axes1++] = i;
    }
  }

  do {
    size_t offset0 = GetOutputOffset(num_dims, input0_dims, input_iter, num_axes0, axes0);
    size_t offset1 = GetOutputOffset(num_dims, input1_dims, input_iter, num_axes1, axes1);
    // input_iter walks dy's index space, so dy must be addressed with dy's own dimensions.
    size_t yt_offset = GetInputOffset(num_dims, dy_dims, input_iter);
    output0[offset0] += input0[offset0] < input1[offset1] ? dy[yt_offset] : 0.0f;
    output1[offset1] += input1[offset1] <= input0[offset0] ? dy[yt_offset] : 0.0f;
  } while (NextIndex(num_dims, dy_dims, input_iter));
}

@ -16,11 +16,17 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_ARITHMETIC_GRAD_H_
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
void ElementDivNegSquare(const float *nom, const float *denom, float *output, int element_size);
void ElementMulAndDivNegSquare(const float *a, const float *b, const float *denom, float *output, int element_size);
void MaximumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims);
void MinimumByAxes(const float *input0, const float *input1, const float *dy, const int *input0_dims,
const int *input1_dims, const int *dy_dims, float *output0, float *output1, int num_dims);
#ifdef __cplusplus
}
#endif

@ -17,66 +17,55 @@
#include <string.h>
#include "nnacl/fp32_grad/batch_norm.h"
void sumSpatialBatch(const float *in, int size, int ch, float *out) {
void sumSpatialBatch(const float *in, size_t size, int ch, float *out) {
memset(out, 0, ch * sizeof(float));
for (int i = 0; i < size; i++) {
const float *ptr = in + i * ch;
for (int c = 0; c < ch; c++) {
for (size_t i = 0; i < size; i++) {
const float *ptr = in + (i * ch);
for (size_t c = 0; c < ch; c++) {
out[c] += ptr[c];
}
}
}
static void meanVar(const float *in, int size, int ch, float eps, float *mean, float *invar) {
float N = (float)(size);
sumSpatialBatch(in, N, ch, mean);
for (int f = 0; f < ch; ++f) {
mean[f] /= N;
}
for (int f = 0; f < ch; f++) {
float tvar = 0;
for (int i = 0; i < N; i++) {
float x = in[i * ch + f];
tvar += (x - mean[f]) * (x - mean[f]);
}
invar[f] = 1.0f / (sqrt(tvar / N + eps));
}
}
void backwardX(const float *in, const float *dout, const float *scale, const int size, int channels, float eps,
float *mean, float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) {
meanVar(in, size, channels, eps, mean, invar);
for (int i = 0; i < size; i++) {
for (int f = 0; f < channels; f++) {
int ix = i * channels + f;
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *dxhathat_sum, float *dxhat_sum, float *out) {
const float N = (size);
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dxhat = dout[ix] * scale[f];
dxhat_sum[f] += dxhat;
dxhathat_sum[f] += dxhat * x_hat;
float dx_hat = dout[ix] * scale[f];
dxhat_sum[f] += dx_hat;
dxhathat_sum[f] += dx_hat * x_hat;
}
}
for (int i = 0; i < size; i++) {
for (int f = 0; f < channels; f++) {
int ix = i * channels + f;
for (size_t i = 0; i < size; i++) {
for (size_t f = 0; f < channels; f++) {
size_t ix = i * channels + f;
float x_hat = (in[ix] - mean[f]) * invar[f];
float dxhat = dout[ix] * scale[f];
out[ix] = 1.f / size * invar[f] * (size * dxhat - dxhat_sum[f] - x_hat * dxhathat_sum[f]);
float dx_hat = dout[ix] * scale[f];
out[ix] = 1.0f / N * (invar[f]) * (N * dx_hat - dxhat_sum[f] - x_hat * dxhathat_sum[f]);
}
}
}
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates) {
int i, b, f;
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates) {
size_t i, b, f;
memset(scale_updates, 0, n * sizeof(float));
for (b = 0; b < batch; ++b) {
for (i = 0; i < size; ++i) {
for (f = 0; f < n; ++f) {
int index = (b * size + i) * n + f;
float x_norm = (x[index] - mean[f]) * invar[f];
scale_updates[f] += delta[index] * x_norm;
scale_updates[f] += (delta[index] * x_norm);
}
}
}
}
void var2Invar(float *save_var, size_t size, float eps) {
for (size_t i = 0; i < size; i++) {
save_var[i] = 1.0f / sqrt(save_var[i] + eps);
}
}

@ -29,11 +29,12 @@ typedef struct BNGradParameter {
extern "C" {
#endif
void sumSpatialBatch(const float *in, int size, int ch, float *out);
void backwardX(const float *in, const float *dout, const float *scale, const int size, int channels, float eps,
float *mean, float *invar, float *xhat_sum, float *dxhat_sum, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch,
int n, int size, float *scale_updates);
void sumSpatialBatch(const float *in, size_t size, int ch, float *out);
void backwardX(const float *in, const float *dout, const float *scale, const size_t size, int channels, float *mean,
float *invar, float *xhat_sum, float *dxhat_sum, float *out);
void backwardScale(const float *x, const float *mean, const float *invar, const float *delta, int batch, int n,
int size, float *scale_updates);
void var2Invar(float *save_var, size_t size, float eps);
#ifdef __cplusplus
}

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32_grad/dropout_grad.h"
// Writes output_ptr[i] = yt_ptr[i] * mask[i] * scale for the first `length` elements.
void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float scale) {
  int idx = 0;
  while (idx < length) {
    const float masked = yt_ptr[idx] * mask[idx];
    output_ptr[idx] = masked * scale;
    ++idx;
  }
}

@ -0,0 +1,31 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_GRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_GRAD_H_
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
// Dropout backward pass over `length` elements of yt_ptr and mask, written to output_ptr.
// NOTE(review): the last parameter is named `ratio` here, but the definition in dropout_grad.c
// names it `scale` and uses it as a plain multiplier — confirm whether callers pass the raw
// dropout ratio or a precomputed scale, and align the name.
void DropoutGrad(const float *yt_ptr, const float *mask, float *output_ptr, int length, float ratio);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_GRAD_H_

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_
#include "nnacl/op_base.h"
// Parameter block for the dropout operator.
typedef struct DropoutParameter {
// Generic operator parameter header; declared first in the struct.
OpParameter op_parameter_;
// NOTE(review): presumably the dropout ratio taken from the model schema — confirm range and meaning with the kernel.
float ratio_;
} DropoutParameter;
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_DROPOUT_PARAMETER_H_

File diff suppressed because it is too large Load Diff

@ -17,11 +17,26 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_GEMM_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_GEMM_H_
#include <stdlib.h>
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
void gemm(int transpose_a, int transpose_b, int M, int N, int K, float alpha, float *mat_a, int lda, float *mat_b,
int ldb, float beta, float *mat_c, int ldc);
// Auxiliary descriptor handed to GemmMatmulPlus through its `cb` argument.
// NOTE(review): the field meanings below are not established by this header — confirm against gemm.c.
typedef struct {
int ca;  // presumably a flag/state associated with matrix A — TODO confirm
int cb;  // presumably a flag/state associated with matrix B — TODO confirm
ActType atype;  // activation type; presumably applied to the GEMM result — TODO confirm
float *bias;  // optional bias pointer — ownership and length not visible here
float *mat_a;  // buffer associated with matrix A — TODO confirm whether packed/cached
float *mat_b;  // buffer associated with matrix B — TODO confirm whether packed/cached
} GemmCb;
void GemmMatmulPlus(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
int ldb, float beta, float *mat_c, int ldc, float *workspace, GemmCb *cb);
void GemmMatmul(int ta, int tb, int M, int N, int K, float alpha, const float *mat_a, int lda, const float *mat_b,
int ldb, float beta, float *mat_c, int ldc, float *workspace);
int MatSize(int row, int col, int round);
int MatSizeTotal(int row, int col, int deep, int inc);
#ifdef __cplusplus
}
#endif

@ -16,10 +16,11 @@
#include <string.h>
#include "nnacl/fp32_grad/pack_ext.h"
#include "nnacl/pack.h"
// Range test done with a single unsigned comparison: for non-negative b this is non-zero exactly
// when 0 <= a < b, because a negative a converts to a very large unsigned value.
static int is_a_ge_zero_and_a_lt_b(int a, int b) {
  const unsigned ua = (unsigned)a;
  const unsigned ub = (unsigned)b;
  return ua < ub;
}
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -35,42 +36,42 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
const int in_height = conv_param->input_h_;
const int in_width = conv_param->input_w_;
const int output_h = conv_param->output_h_;
const int output_w = conv_param->output_w_;
const int channels = conv_param->input_channel_ / conv_param->group_;
const int tot_channels = conv_param->input_channel_;
int kernel_row, kernel_col, output_rows, output_col;
int row_stride_offset = 0;
int kernel_row, kernel_col;
for (output_rows = output_h; output_rows; output_rows--) {
int col_stride_offset = 0;
for (output_col = output_w; output_col; output_col--) {
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + row_stride_offset;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset;
for (int i = 0; i < rows; i++) {
int block_start = start + i;
int input_h = block_start / output_w * stride_h;
int input_w = block_start % output_w * stride_w;
for (kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
int input_row = -pad_up + kernel_row * dilation_h + input_h;
for (kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_col = -pad_left + kernel_col * dilation_w + input_w;
if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels;
memcpy(data_col, in_data + offset, sizeof(float) * channels);
data_col += channels;
} else {
memset(data_col, 0, sizeof(float) * channels);
data_col += channels;
}
if (is_a_ge_zero_and_a_lt_b(input_row, in_height) && is_a_ge_zero_and_a_lt_b(input_col, in_width)) {
const int offset = (input_row * in_width + input_col) * tot_channels;
memcpy(data_col, in_data + offset, sizeof(float) * channels);
data_col += channels;
} else {
memset(data_col, 0, sizeof(float) * channels);
data_col += channels;
}
}
col_stride_offset += stride_w;
}
row_stride_offset += stride_h;
}
}
// Thin adapter around rolling_im2col_hwc with a different argument order:
// packs `real_cal_num` output positions, starting at flat position `block_index`,
// from input_data into packed_input.
void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index) {
// real_cal_num is forwarded as `rows`, block_index as `start`.
rolling_im2col_hwc(input_data, packed_input, conv_param, real_cal_num, block_index);
}
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose) {
void im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, bool transpose) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -150,7 +151,56 @@ void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param
}
}
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param) {
// Gathers a band of `rows` rows (row indices start .. start + rows - 1) from the HWC tensor in_data
// into data_row. Output is ordered channel -> kernel_row -> kernel_col -> row -> column, and
// positions that fall outside the source plane are written as 0.
// Note: the source plane extent is read from conv_param->output_h_/output_w_ and the number of
// columns emitted per row from conv_param->input_w_.
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start) {
  const int pad_left = conv_param->pad_l_;
  const int pad_up = conv_param->pad_u_;
  const int stride_h = conv_param->stride_h_;
  const int stride_w = conv_param->stride_w_;
  const int dilation_h = conv_param->dilation_h_;
  const int dilation_w = conv_param->dilation_w_;
  const int kernel_h = conv_param->kernel_h_;
  const int kernel_w = conv_param->kernel_w_;

  const int src_height = conv_param->output_h_;
  const int src_width = conv_param->output_w_;
  const int cols_per_row = conv_param->input_w_;
  const int tot_channels = conv_param->output_channel_;
  const int channels = tot_channels / conv_param->group_;

  float *dst = data_row;
  for (int ch = 0; ch < channels; ch++) {
    for (int kr = 0; kr < kernel_h; kr++) {
      for (int kc = 0; kc < kernel_w; kc++) {
        for (int out_row = start; out_row < start + rows; out_row++) {
          const int src_row = -pad_up + kr * dilation_h + out_row * stride_h;
          int remaining = cols_per_row;

          // Whole row lies in the padding: emit a row of zeros.
          if (!is_a_ge_zero_and_a_lt_b(src_row, src_height)) {
            while (remaining) {
              *(dst++) = 0;
              remaining--;
            }
            continue;
          }

          // Row is inside the plane: step across it by stride_w, zero-filling out-of-range columns.
          int src_col = -pad_left + kc * dilation_w;
          while (remaining) {
            if (is_a_ge_zero_and_a_lt_b(src_col, src_width)) {
              const int offset = (src_row * src_width + src_col) * tot_channels + ch;
              *(dst++) = in_data[offset];
            } else {
              *(dst++) = 0;
            }
            src_col += stride_w;
            remaining--;
          }
        }
      }
    }
  }
}
void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param) {
const int pad_left = conv_param->pad_l_;
const int pad_up = conv_param->pad_u_;
@ -198,3 +248,52 @@ void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param
row_stride_offset += stride_h;
}
}
// Scatter-adds `rows` column patches from data_col back into the HWC image data_im.
// Patch r corresponds to flat output position (start + r); its output row/column are derived
// from conv_param->output_w_. For each kernel tap, `channels` values are consumed from data_col
// and accumulated into data_im when the tap lands inside the image; taps in the padding are
// skipped but still consume their slot in data_col.
void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start) {
  const int pad_left = conv_param->pad_l_;
  const int pad_up = conv_param->pad_u_;
  const int stride_h = conv_param->stride_h_;
  const int stride_w = conv_param->stride_w_;
  const int dilation_h = conv_param->dilation_h_;
  const int dilation_w = conv_param->dilation_w_;
  const int kernel_h = conv_param->kernel_h_;
  const int kernel_w = conv_param->kernel_w_;
  const int in_height = conv_param->input_h_;
  const int in_width = conv_param->input_w_;
  const int output_w = conv_param->output_w_;
  const int channels = conv_param->input_channel_ / conv_param->group_;
  const int tot_channels = conv_param->input_channel_;

  const float *src = data_col;
  for (int r = 0; r < rows; r++) {
    const int patch = start + r;
    const int col_origin = (patch % output_w) * stride_w;
    const int row_origin = (patch / output_w) * stride_h;

    for (int kr = 0; kr < kernel_h; kr++) {
      const int im_row = -pad_up + kr * dilation_h + row_origin;
      for (int kc = 0; kc < kernel_w; kc++) {
        const int im_col = -pad_left + kc * dilation_w + col_origin;
        if (is_a_ge_zero_and_a_lt_b(im_row, in_height) && is_a_ge_zero_and_a_lt_b(im_col, in_width)) {
          float *dst = data_im + (im_row * in_width + im_col) * tot_channels;
          for (int c = 0; c < channels; c++) {
            dst[c] += src[c];
          }
        }
        // Advance past this tap's slot whether or not it was inside the image.
        src += channels;
      }
    }
  }
}

@ -17,14 +17,18 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_PACK_EXT_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_PACK_EXT_H_
#include <stddef.h>
#include "nnacl/conv_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param);
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param, bool transpose);
void col2im_hwc(const float *data_col, float *data_im, ConvParameter *conv_param);
void RollingIm2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input,
int real_cal_num, int block_index);
void rolling_im2col_hwc(const float *in_data, float *data_col, const ConvParameter *conv_param, int rows, int start);
void rolling_im2row_hwc(const float *in_data, float *data_row, const ConvParameter *conv_param, int rows, int start);
void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv_param, int rows, int start);
#ifdef __cplusplus
}
#endif

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <stdint.h>
#include <string.h>
#include <float.h>
#include "nnacl/fp32_grad/pooling_grad.h"
@ -31,8 +32,7 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, PoolingParameter
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
float kk = (float)(win_h * win_w);
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out = &output_ptr[(ib * in_h * in_w * channel)];
@ -77,8 +77,7 @@ void MaxPoolingGrad(const float *input_ptr, const float *dx_ptr, const float *dy
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
for (int i = 0; i < in_h * in_w * channel * output_batch; i++) output_ptr[i] = 0.0;
memset(output_ptr, 0, in_h * in_w * channel * output_batch * sizeof(float));
for (uint16_t ib = 0; ib < output_batch; ib++) {
float *out = &output_ptr[(ib * in_h * in_w * channel)];
const float *inPtr = (const float *)(&input_ptr[(ib * in_h * in_w * channel)]);

@ -15,50 +15,7 @@
*/
#include <string.h>
#include "nnacl/fp32_grad/reduce_grad.h"
static inline int NextIndex(const int num_dims, const int *dims, int *current) {
int carry = 1;
for (int idx = num_dims - 1; idx >= 0; --idx) {
int current_val = current[idx] + carry;
if (dims[idx] == current_val) {
current[idx] = 0;
} else {
current[idx] = current_val;
carry = 0;
break;
}
}
return (carry == 0);
}
static inline size_t GetInputOffset(const int num_dims, const int *dims, const int *iter) {
size_t offset = 0;
for (int idx = 0; idx < num_dims; ++idx) {
offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]);
}
return offset;
}
static inline size_t GetOutputOffset(const int num_dims, const int *dims, const int *iter, const int num_axis,
const int *axes) {
size_t offset = 0;
for (int idx = 0; idx < num_dims; ++idx) {
// if we need to skip this axis
int is_axis = 0;
for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
if (idx == axes[axis_idx]) {
is_axis = 1;
break;
}
}
if (!is_axis) {
offset = offset * (size_t)(dims[idx]) + (size_t)(iter[idx]);
}
}
return offset;
}
#include "nnacl/fp32_grad/utils.h"
void ReduceMeanByAxes(const float *input_data, int *input_iter, const int *input_dims, int input_num_dims,
const int *axes, int num_axes, float *output_data, const int *output_dims, int output_num_dims) {
@ -111,7 +68,7 @@ void ReduceSumByAxes(const float *input, const int *input_dims, float *output, c
return;
}
for (int idx = 0; idx < num_outputs; ++idx) output[idx] = 0; // zero output
memset(output, 0, num_outputs * sizeof(float)); // zero output
int input_iter[8] = {0};
int axes[5] = {0};

@ -41,7 +41,6 @@ void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr,
const int M = input_shape[axis];
const int N = inner_size;
const int K = 1;
for (int i = 0; i < outter_size; i++) {
int outter_offset = i * dim;
memset(sum_data, 0.0f, inner_size * sizeof(float));
@ -52,7 +51,14 @@ void SoftmaxGrad(const float *input_ptr, const float *yt_ptr, float *output_ptr,
sum_data[k] += output_ptr[offset] * input_ptr[offset];
}
}
gemm(0, 0, M, N, K, -1, sum_mul, K, sum_data, N, 1, &output_ptr[outter_offset], N);
for (int k = 0; k < M; ++k) {
float a = -sum_mul[k];
for (int j = 0; j < N; ++j) {
*(output_ptr + outter_offset + k * N + j) += a * sum_data[j];
}
}
// gemm(0, 0, M, N, K, -1, sum_mul, K, sum_data, N, 1, &output_ptr[outter_offset], N);
}
for (int i = 0; i < ele_size; i++) {

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_GRAD_UTILS_H_
#define MINDSPORE_LITE_NNACL_FP32_GRAD_UTILS_H_
#include "nnacl/op_base.h"
#ifdef __cplusplus
extern "C" {
#endif
// Flattens the multi-dimensional index `iter` into a row-major offset for a tensor of shape `dims`.
static inline size_t GetInputOffset(int num_dims, const int *dims, const int *iter) {
  size_t flat = 0;
  int axis = 0;
  while (axis < num_dims) {
    flat *= (size_t)(dims[axis]);
    flat += (size_t)(iter[axis]);
    ++axis;
  }
  return flat;
}
// Row-major offset of `iter` in a tensor of shape `dims`, ignoring every dimension listed in `axes`
// (the first num_axis entries): listed dimensions contribute neither their extent nor their index.
static inline size_t GetOutputOffset(int num_dims, const int *dims, const int *iter, int num_axis, const int *axes) {
  size_t flat = 0;
  for (int dim = 0; dim < num_dims; ++dim) {
    // Look for this dimension among the axes to skip.
    int pos = 0;
    while (pos < num_axis && axes[pos] != dim) {
      ++pos;
    }
    if (pos < num_axis) {
      continue;  // found: this dimension is skipped
    }
    flat = flat * (size_t)(dims[dim]) + (size_t)(iter[dim]);
  }
  return flat;
}
// Advances the multi-dimensional index `current` by one position in row-major order over `dims`
// (last dimension varies fastest). Returns 1 if a new index was produced, or 0 when every
// dimension wrapped back to zero (or num_dims is 0), i.e. the walk is finished.
static inline int NextIndex(int num_dims, const int *dims, int *current) {
  for (int axis = num_dims - 1; axis >= 0; --axis) {
    const int bumped = current[axis] + 1;
    if (dims[axis] != bumped) {
      current[axis] = bumped;
      return 1;
    }
    // This axis overflowed: reset it and carry into the next more-significant axis.
    current[axis] = 0;
  }
  return 0;
}
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_GRAD_UTILS_H_

@ -234,6 +234,9 @@ union PrimitiveType {
BinaryCrossEntropyGrad,
BinaryCrossEntropy,
LpNormalization,
DropoutGrad,
MaximumGrad,
MinimumGrad
}
enum QuantType: int {

@ -224,6 +224,7 @@ table Conv2DGradFilter {
dilateW: int;
dilateH: int;
hasBias: bool = false;
filter_shape: [int];
activationType: ActivationType = 0;
}
@ -244,6 +245,7 @@ table Conv2DGradInput {
dilateW: int;
dilateH: int;
hasBias: bool = false;
input_shape: [int];
activationType: ActivationType = 0;
}
@ -264,6 +266,7 @@ table GroupConv2DGradInput {
dilateW: int;
dilateH: int;
hasBias: bool = false;
input_shape: [int];
activationType: ActivationType = 0;
}
@ -478,13 +481,10 @@ table DeConv2DGradFilter {
}
table BNGrad {
eps : float;
momentum: float;
}
table BNGradInput {
eps : float;
eps: float;
momentum: float;
}
table Scale {
axis: int;
activationType: ActivationType = 0;
@ -1087,6 +1087,16 @@ table FftReal {
table FftImag {
}
table DropoutGrad {
ratio : float = 0.5;
}
table MaximumGrad {
}
table MinimumGrad {
}
table NonMaxSuppression {
centerPointBox : int = 0;
}

@ -95,13 +95,23 @@ class LiteKernel {
std::string name() const { return this->name_; }
virtual void train() { train_mode_ = true; }
virtual int Train() {
this->train_mode_ = true;
return mindspore::lite::RET_OK;
}
virtual bool IsTrain() const { return this->train_mode_; }
virtual int Eval() {
this->train_mode_ = false;
return mindspore::lite::RET_OK;
}
virtual bool is_train() { return train_mode_; }
virtual bool IsEval() const { return !this->train_mode_; }
virtual void eval() { train_mode_ = false; }
virtual void SetTrainable(bool trainable = true) { this->trainable_ = trainable; }
virtual bool is_eval() { return !train_mode_; }
virtual bool IsTrainable() const { return this->trainable_; }
void set_name(const std::string &name) { this->name_ = name; }
@ -179,6 +189,7 @@ class LiteKernel {
std::vector<LiteKernel *> in_kernels_;
std::vector<LiteKernel *> out_kernels_;
bool train_mode_ = false;
bool trainable_ = false; // parameters of this Kernel are trained in Train Session
bool is_model_output_ = false;
size_t workspace_size_ = 0;
static void *workspace_;

@ -73,7 +73,7 @@ Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator);
int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
if (10 != inputs.size()) {
MS_LOG(ERROR) << "Adam should have at 10 input tensors";
MS_LOG(ERROR) << "Adam should have 10 input tensors";
return RET_ERROR;
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save