add_resize_ops

pull/10198/head
wanyiming 4 years ago
parent 83a29a0a54
commit 2283eac723

@ -826,5 +826,26 @@ std::string GetProcessorStr(const AnfNodePtr &anf_node) {
return processor;
}
float Scaling(size_t in_size, size_t out_size, bool align_corners) {
return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
: in_size / static_cast<float>(out_size);
}
float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
CachedInterpolation *interpolation) {
interpolation[out_size].lower = 0;
interpolation[out_size].upper = 0;
for (size_t i = 0; i <= out_size - 1; ++i) {
const float in = ScaleGrid(i, scale);
const float in_f = std::floor(in);
interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
interpolation[i].lerp = in - in_f;
}
}
} // namespace kernel
} // namespace mindspore

@ -102,6 +102,16 @@ void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<An
bool IsWeightBoundary(const AnfNodePtr &node);
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode);
std::string GetProcessorStr(const AnfNodePtr &anf_node);
float Scaling(size_t in_size, size_t out_size, bool align_corners);
float ScaleGrid(const int x, const float scale);
struct CachedInterpolation {
size_t lower;
size_t upper;
float lerp;
};
void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
CachedInterpolation *interpolation);
template <typename T>
inline std::string Vector2Str(const std::vector<T> &inputs) {
@ -113,6 +123,14 @@ inline std::string Vector2Str(const std::vector<T> &inputs) {
}
return "";
}
template <typename T>
inline T ComputeLerp(T top_left, T top_right, T bottom_left, T bottom_right, T x_lerp, T y_lerp) {
T top = top_left + (top_right - top_left) * x_lerp;
T bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
return top + (bottom - top) * y_lerp;
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,113 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
void ResizeBilinearCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
size_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
align_corners_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "align_corners");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
size_t in_height = shape_[2];
size_t in_width = shape_[3];
size_t out_height = size_[0];
size_t out_width = size_[1];
height_scale = Scaling(in_height, out_height, align_corners_);
width_scale = Scaling(in_width, out_width, align_corners_);
}
bool ResizeBilinearCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16, float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float, float>(inputs, outputs);
}
return true;
}
template <typename T1, typename T2>
void ResizeBilinearCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T1 *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T2 *>(outputs[0]->addr);
size_t batch_size = shape_[0];
size_t channel = shape_[1];
size_t in_height = shape_[2];
size_t in_width = shape_[3];
size_t out_height = size_[0];
size_t out_width = size_[1];
size_t out_hw_size = out_height * out_width;
size_t in_hw_size = in_height * in_width;
size_t bhwc_size = in_hw_size * channel * batch_size;
if (out_height == in_height && out_width == in_width) {
for (size_t i = 0; i < bhwc_size; ++i) {
output_addr[i] = static_cast<float>(input_addr[i]);
}
}
std::vector<CachedInterpolation> ys(out_height + 1);
std::vector<CachedInterpolation> xs(out_width + 1);
ComputeInterpolationWeights(out_height, in_height, height_scale, ys.data());
ComputeInterpolationWeights(out_width, in_width, width_scale, xs.data());
for (size_t b = 0; b < batch_size; ++b) {
for (size_t c = 0; c < channel; ++c) {
for (size_t h = 0; h < out_height; ++h) {
const T1 *ys_input_lower_ptr = input_addr + ys[h].lower * in_width;
const T1 *ys_input_upper_ptr = input_addr + ys[h].upper * in_width;
const T2 ys_lerp = T2(ys[h].lerp);
for (size_t w = 0; w < out_width; ++w) {
const size_t xs_lower = xs[w].lower;
const size_t xs_upper = xs[w].upper;
const T2 xs_lerp = T2(xs[w].lerp);
const T2 top_left(ys_input_lower_ptr[xs_lower]);
const T2 top_right(ys_input_lower_ptr[xs_upper]);
const T2 bottom_left(ys_input_upper_ptr[xs_lower]);
const T2 bottom_right(ys_input_upper_ptr[xs_upper]);
output_addr[h * out_width + w] =
ComputeLerp(top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp);
}
}
output_addr += out_hw_size;
input_addr += in_hw_size;
}
}
}
void ResizeBilinearCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear needs 1 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear expects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,58 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ResizeBilinearCPUKernel : public CPUKernel {
public:
ResizeBilinearCPUKernel() = default;
~ResizeBilinearCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T1, typename T2>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
bool align_corners_ = false;
float height_scale;
float width_scale;
std::vector<int64_t> size_;
std::vector<size_t> shape_;
};
MS_REG_CPU_KERNEL(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
ResizeBilinearCPUKernel);
MS_REG_CPU_KERNEL(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeBilinearCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_

@ -0,0 +1,106 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
void ResizeBilinearGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
size_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
align_corners_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "align_corners");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
size_t in_height = shape_[2];
size_t in_width = shape_[3];
size_t out_height = size_[2];
size_t out_width = size_[3];
height_scale = Scaling(out_height, in_height, align_corners_);
width_scale = Scaling(out_width, in_width, align_corners_);
}
bool ResizeBilinearGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
}
return true;
}
template <typename T>
void ResizeBilinearGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto dloss_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t batch_size = shape_[0];
size_t channel = shape_[1];
size_t in_height = shape_[2];
size_t in_width = shape_[3];
size_t out_height = size_[2];
size_t out_width = size_[3];
size_t out_hw_size = out_height * out_width;
size_t in_hw_size = in_height * in_width;
for (size_t b = 0; b < batch_size; ++b) {
for (size_t c = 0; c < channel; ++c) {
for (size_t h = 0; h < in_height; ++h) {
const float in_y = static_cast<float>(h) * height_scale;
const size_t top_y_index = std::max(static_cast<size_t>(floorf(in_y)), static_cast<size_t>(0));
const size_t bottom_y_index = std::min(static_cast<size_t>(ceilf(in_y)), out_height - 1);
const float y_lerp = in_y - floorf(in_y);
const float inverse_y_lerp = 1.0 - y_lerp;
for (size_t w = 0; w < in_width; ++w) {
const float in_x = static_cast<float>(w) * width_scale;
const size_t left_x_index = std::max(static_cast<size_t>(floorf(in_x)), static_cast<size_t>(0));
const size_t right_x_index = std::min(static_cast<size_t>(ceilf(in_x)), out_width - 1);
const float x_lerp = in_x - floorf(in_x);
const float inverse_x_lerp = 1.0 - x_lerp;
output_addr[top_y_index * out_width + left_x_index] +=
dloss_addr[h * in_width + w] * T(inverse_y_lerp * inverse_x_lerp);
output_addr[top_y_index * out_width + right_x_index] +=
dloss_addr[h * in_width + w] * T(inverse_y_lerp * x_lerp);
output_addr[bottom_y_index * out_width + left_x_index] +=
dloss_addr[h * in_width + w] * T(y_lerp * inverse_x_lerp);
output_addr[bottom_y_index * out_width + right_x_index] += dloss_addr[h * in_width + w] * T(y_lerp * x_lerp);
}
}
output_addr += out_hw_size;
dloss_addr += in_hw_size;
}
}
}
void ResizeBilinearGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(EXCEPTION) << "ResizeBilinearGrad needs 2 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear Gradexpects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ResizeBilinearGradCPUKernel : public CPUKernel {
public:
ResizeBilinearGradCPUKernel() = default;
~ResizeBilinearGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
bool align_corners_ = false;
float height_scale;
float width_scale;
std::vector<size_t> size_;
std::vector<size_t> shape_;
};
MS_REG_CPU_KERNEL(
ResizeBilinearGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeBilinearGradCPUKernel);
MS_REG_CPU_KERNEL(
ResizeBilinearGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeBilinearGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_

@ -0,0 +1,94 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
void ResizeNearestNeighborCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<int64_t> output_size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
align_corners_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "align_corners");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
batch_size_ = input_shape[0];
channel_ = input_shape[1];
in_height_ = input_shape[2];
in_width_ = input_shape[3];
out_height_ = output_size[0];
out_width_ = output_size[1];
height_scale_ = Scaling(in_height_, out_height_, align_corners_);
width_scale_ = Scaling(in_width_, out_width_, align_corners_);
output_size_ = batch_size_ * channel_ * out_height_ * out_width_;
}
bool ResizeNearestNeighborCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int32_t>(inputs, outputs);
}
return true;
}
template <typename T>
void ResizeNearestNeighborCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
if (out_height_ == in_height_ && out_width_ == in_width_) {
for (size_t i = 0; i < output_size_; ++i) {
output_addr[i] = input_addr[i];
}
}
for (size_t i = 0; i < output_size_; ++i) {
size_t pos0 = i / (channel_ * out_height_ * out_width_) % batch_size_;
size_t pos1 = i / (out_height_ * out_width_) % channel_;
size_t pos2 = i / (out_width_) % out_height_;
size_t pos3 = i % out_width_;
const size_t in_y = std::min((align_corners_) ? static_cast<size_t>(roundf(pos2 * height_scale_))
: static_cast<size_t>(floorf(pos2 * height_scale_)),
in_height_ - 1);
const size_t in_x = std::min((align_corners_) ? static_cast<size_t>(roundf(pos3 * width_scale_))
: static_cast<size_t>(floorf(pos3 * width_scale_)),
in_width_ - 1);
size_t input_pos =
pos0 * channel_ * in_height_ * in_width_ + pos1 * in_height_ * in_width_ + in_y * in_width_ + in_x;
output_addr[i] = input_addr[input_pos];
}
}
void ResizeNearestNeighborCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear needs 1 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear expects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ResizeNearestNeighborCPUKernel : public CPUKernel {
public:
ResizeNearestNeighborCPUKernel() = default;
~ResizeNearestNeighborCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
bool align_corners_{false};
size_t batch_size_{0};
size_t channel_{0};
size_t in_height_{0};
size_t in_width_{0};
size_t out_height_{0};
size_t out_width_{0};
size_t output_size_{0};
float height_scale_{1.0};
float width_scale_{1.0};
};
MS_REG_CPU_KERNEL(ResizeNearestNeighbor,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeNearestNeighborCPUKernel);
MS_REG_CPU_KERNEL(ResizeNearestNeighbor,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeNearestNeighborCPUKernel);
MS_REG_CPU_KERNEL(ResizeNearestNeighbor, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ResizeNearestNeighborCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_

@ -0,0 +1,91 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
void ResizeNearestNeighborGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
std::vector<size_t> output_size = AnfAlgo::GetOutputInferShape(kernel_node, 0);
align_corners_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "align_corners");
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
batch_size_ = input_shape[0];
channel_ = input_shape[1];
in_height_ = input_shape[2];
in_width_ = input_shape[3];
out_height_ = output_size[2];
out_width_ = output_size[3];
height_scale_ = Scaling(out_height_, in_height_, align_corners_);
width_scale_ = Scaling(out_width_, in_width_, align_corners_);
}
bool ResizeNearestNeighborGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) {
LaunchKernel<int32_t>(inputs, outputs);
}
return true;
}
template <typename T>
void ResizeNearestNeighborGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto dloss_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t in_hw_size = in_width_ * in_height_;
size_t out_hw_size = out_width_ * out_height_;
for (size_t b = 0; b < batch_size_; ++b) {
for (size_t c = 0; c < channel_; ++c) {
for (size_t h = 0; h < in_height_; ++h) {
const size_t out_y = std::min((align_corners_) ? static_cast<size_t>(roundf(h * height_scale_))
: static_cast<size_t>(floorf(h * height_scale_)),
out_height_ - 1);
for (size_t w = 0; w < in_width_; ++w) {
const size_t out_x = std::min((align_corners_) ? static_cast<size_t>(roundf(w * width_scale_))
: static_cast<size_t>(floorf(w * width_scale_)),
out_width_ - 1);
output_addr[out_y * out_width_ + out_x] += dloss_addr[h * in_width_ + w];
}
}
output_addr += out_hw_size;
dloss_addr += in_hw_size;
}
}
}
void ResizeNearestNeighborGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinearGrad needs 1 inputs, but gets " << input_num;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "ResizeBilinear Gradexpects 1 output, but gets" << output_num;
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class ResizeNearestNeighborGradCPUKernel : public CPUKernel {
public:
ResizeNearestNeighborGradCPUKernel() = default;
~ResizeNearestNeighborGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
private:
void CheckParam(const CNodePtr &kernel_node);
TypeId dtype_{kTypeUnknown};
bool align_corners_{false};
size_t batch_size_{0};
size_t channel_{0};
size_t in_height_{0};
size_t in_width_{0};
size_t out_height_{0};
size_t out_width_{0};
float height_scale_{1.0};
float width_scale_{1.0};
};
MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ResizeNearestNeighborGradCPUKernel);
MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ResizeNearestNeighborGradCPUKernel);
MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ResizeNearestNeighborGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_

@ -0,0 +1,83 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class ResizeBilinearGradAlignCornerT(nn.Cell):
def __init__(self):
super(ResizeBilinearGradAlignCornerT, self).__init__()
self.ResizeBilinearGradAlignCornerT = G.ResizeBilinearGrad(
align_corners=True)
def construct(self, dy, size):
return self.ResizeBilinearGradAlignCornerT(dy, size)
class ResizeBilinearGradAlignCornerF(nn.Cell):
def __init__(self):
super(ResizeBilinearGradAlignCornerF, self).__init__()
self.ResizeBilinearGradAlignCornerF = G.ResizeBilinearGrad(align_corners=False)
def construct(self, dy, size):
return self.ResizeBilinearGradAlignCornerF(dy, size)
def test_ResizeBilinearGradAlignCornerT():
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
orign_image = np.array(
[[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float16)
expect = np.array([[[[1., 0., 0., 2.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[3., 0., 0., 4.]]]]).astype(np.float16)
rnn = ResizeBilinearGradAlignCornerT()
output = rnn(Tensor(dy), Tensor(orign_image))
assert np.all(output.asnumpy() == expect)
orign_image = np.array(
[[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float32)
expect = np.array([[[[1., 0., 0., 2.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[3., 0., 0., 4.]]]]).astype(np.float32)
rnn = ResizeBilinearGradAlignCornerT()
output = rnn(Tensor(dy), Tensor(orign_image))
assert np.all(output.asnumpy() == expect)
def test_ResizeBilinearGradAlignCornerF():
dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
orign_image = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float16)
expect = np.array([[[[2.25, 0.75],
[0.75, 4.25]]]]).astype(np.float16)
rnn = ResizeBilinearGradAlignCornerF()
output = rnn(Tensor(dy), Tensor(orign_image))
assert np.all(output.asnumpy() == expect)
orign_image = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32)
expect = np.array([[[[2.25, 0.75],
[0.75, 4.25]]]]).astype(np.float32)
rnn = ResizeBilinearGradAlignCornerF()
output = rnn(Tensor(dy), Tensor(orign_image))
assert np.all(output.asnumpy() == expect)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,93 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class ResizeNearestNeighborGradAlignCornerT(nn.Cell):
def __init__(self, size=None):
super(ResizeNearestNeighborGradAlignCornerT, self).__init__()
self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(
align_corners=True)
self.size = size
def construct(self, dy):
return self.ResizeNearestNeighborGradAlignCornerT(dy, self.size)
class ResizeNearestNeighborGradAlignCornerF(nn.Cell):
def __init__(self, size=None):
super(ResizeNearestNeighborGradAlignCornerF, self).__init__()
self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(
align_corners=False)
self.size = size
def construct(self, dy):
return self.ResizeNearestNeighborGradAlignCornerF(dy, self.size)
def test_ResizeNearestNeighborGradAlignCornerT():
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32)
size = (4, 4)
expect = np.array(
[[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerT(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16)
size = (4, 4)
expect = np.array(
[[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerT(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)
dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32)
size = (4, 4)
expect = np.array(
[[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32)
rnn = ResizeNearestNeighborGradAlignCornerT(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)
def test_ResizeNearestNeighborGradAlignCornerF():
dy = np.array(
[[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32)
size = (2, 2)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32)
rnn = ResizeNearestNeighborGradAlignCornerF(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)
dy = np.array(
[[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16)
size = (2, 2)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16)
rnn = ResizeNearestNeighborGradAlignCornerF(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)
dy = np.array(
[[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32)
size = (2, 2)
expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32)
rnn = ResizeNearestNeighborGradAlignCornerF(size=size)
output = rnn(Tensor(dy))
assert np.all(output.asnumpy() == expect)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save