|
|
|
@ -42,7 +42,6 @@ int ArithmeticGradCPUKernel::Init() {
|
|
|
|
|
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul2L;
|
|
|
|
|
else if (Type() == PrimitiveType_DivGrad)
|
|
|
|
|
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradDiv2L;
|
|
|
|
|
|
|
|
|
|
} else if (dx2->ElementsNum() < dx1->ElementsNum()) {
|
|
|
|
|
if (Type() == PrimitiveType_MulGrad)
|
|
|
|
|
arithmetic_grad_ = &ArithmeticGradCPUKernel::ArithmeticGradMul1L;
|
|
|
|
@ -75,25 +74,28 @@ int ArithmeticGradCPUKernel::Init() {
|
|
|
|
|
|
|
|
|
|
// Backward pass for element-wise addition: fills dx1 and dx2 from dy.
//
// dy       - incoming gradient buffer, holding dy_size floats.
// dx1, dx2 - output gradient buffers for the first and second input.
// dx1_size, dx2_size - element counts of dx1 and dx2.
//
// For each input, when its element count equals dy_size the gradient is a
// plain copy of dy. Otherwise dy is passed through ReduceSumByAxes together
// with the output shape and that input's shape.
// NOTE(review): ReduceSumByAxes is defined elsewhere; presumably it sums dy
// over the axes along which the input was broadcast -- confirm.
void ArithmeticGradCPUKernel::ArithmeticGradAdd(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
                                                int dx2_size) {
  // Gradient w.r.t. the first input.
  if (dx1_size == dy_size) {
    memcpy(dx1, dy, dy_size * sizeof(float));
  } else {
    ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_,
                    arithmeticParameter_->ndim_);
  }
  // Gradient w.r.t. the second input.
  if (dx2_size == dy_size) {
    memcpy(dx2, dy, dy_size * sizeof(float));
  } else {
    ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx2, arithmeticParameter_->in_shape1_,
                    arithmeticParameter_->ndim_);
  }
}
|
|
|
|
|
|
|
|
|
|
void ArithmeticGradCPUKernel::ArithmeticGradSub(float *dy, int dy_size, float *dx1, int dx1_size, float *dx2,
|
|
|
|
|
int dx2_size) {
|
|
|
|
|
if (dx1_size == dy_size)
|
|
|
|
|
if (dx1_size == dy_size) {
|
|
|
|
|
memcpy(dx1, dy, dy_size * sizeof(float));
|
|
|
|
|
else
|
|
|
|
|
} else {
|
|
|
|
|
ReduceSumByAxes(dy, arithmeticParameter_->out_shape_, dx1, arithmeticParameter_->in_shape0_,
|
|
|
|
|
arithmeticParameter_->ndim_);
|
|
|
|
|
}
|
|
|
|
|
if (dx2_size == dy_size) {
|
|
|
|
|
for (int i = 0; i < dx2_size; i++) {
|
|
|
|
|
dx2[i] = -dy[i];
|
|
|
|
@ -156,7 +158,9 @@ void ArithmeticGradCPUKernel::ArithmeticGradDiv1L(float *dy, int dy_size, float
|
|
|
|
|
arithmeticParameter_); // broadcast directly to dx1
|
|
|
|
|
ReduceSumByAxes(tile_data2, arithmeticParameter_->in_shape0_, dx2, arithmeticParameter_->in_shape1_,
|
|
|
|
|
arithmeticParameter_->ndim_);
|
|
|
|
|
for (int i = 0; i < dx2_size; i++) dx2[i] = -dx2[i];
|
|
|
|
|
for (int i = 0; i < dx2_size; i++) {
|
|
|
|
|
dx2[i] = -dx2[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// broadcasting x2
|
|
|
|
|
BroadcastDiv(dy, x2_data, tile_data0, tile_data1, dx1, dy_size, arithmeticParameter_); // broadcast directly to dx1
|
|
|
|
@ -214,6 +218,7 @@ int ArithmeticGradCPUKernel::Execute(int task_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int ArithmeticGradRun(void *cdata, int task_id) {
|
|
|
|
|
MS_ASSERT(cdata != nullptr);
|
|
|
|
|
auto Arithmetic_kernel = reinterpret_cast<ArithmeticGradCPUKernel *>(cdata);
|
|
|
|
|
auto error_code = Arithmetic_kernel->Execute(task_id);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|