parent
ca6756b5fe
commit
f679568d86
@ -0,0 +1,33 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registers GatherNdGpuFwdKernel<T, S> for every supported combination of data type T (input 0 and output)
// and index type S (input 1).

// float32 data, int32 indices.
MS_REG_GPU_KERNEL_TWO(GatherNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      GatherNdGpuFwdKernel, float, int)

// float16 data, int32 indices.
MS_REG_GPU_KERNEL_TWO(GatherNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeFloat16),
                      GatherNdGpuFwdKernel, half, int)

// int32 data, int32 indices.
MS_REG_GPU_KERNEL_TWO(GatherNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      GatherNdGpuFwdKernel, int, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,162 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHERND_GPU_KERNEL_H
|
||||
#define MINDSPORE_GATHERND_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class GatherNdGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr) {}
|
||||
~GatherNdGpuFwdKernel() {
|
||||
if (dev_batch_strides_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_strides_));
|
||||
}
|
||||
if (dev_batch_indices_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_indices_));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
S *indices_addr = GetDeviceAddress<S>(inputs, 1);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_,
|
||||
dev_batch_indices_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
InitResource();
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherNdGpuFwdKernel needs 2.";
|
||||
}
|
||||
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
Reshape();
|
||||
|
||||
size_t dim_indices_last = dims_[dims_.size() - 1];
|
||||
batch_strides_.resize(dim_indices_last, 0);
|
||||
batch_indices_.resize(dim_indices_last, 0);
|
||||
|
||||
if (dim_indices_last > 0) {
|
||||
batch_strides_[dim_indices_last - 1] = input_shapes_[dim_indices_last - 1];
|
||||
batch_indices_[dim_indices_last - 1] = dims_[1];
|
||||
}
|
||||
for (size_t i = dim_indices_last - 1; i > 0; --i) {
|
||||
batch_strides_[i - 1] = input_shapes_[i - 1];
|
||||
batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
|
||||
}
|
||||
|
||||
size_t strides_len = sizeof(S) * batch_strides_.size();
|
||||
void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len);
|
||||
if (dev_batch_strides_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len;
|
||||
}
|
||||
dev_batch_strides_ = static_cast<S *>(dev_batch_strides_work);
|
||||
|
||||
size_t indices_len = sizeof(S) * batch_indices_.size();
|
||||
void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
|
||||
if (dev_batch_indices_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len;
|
||||
}
|
||||
dev_batch_indices_ = static_cast<S *>(dev_batch_indices_work);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_strides_, &batch_strides_[0], strides_len, cudaMemcpyHostToDevice),
|
||||
"cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_indices_, &batch_indices_[0], indices_len, cudaMemcpyHostToDevice),
|
||||
"cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t size = GetSize(input_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
size = GetSize(indices_shapes_);
|
||||
input_size_list_.push_back(size);
|
||||
|
||||
size = GetSize(output_shapes_);
|
||||
output_size_list_.push_back(size);
|
||||
}
|
||||
|
||||
private:
|
||||
void Reshape() {
|
||||
size_t dim_of_indices = 1;
|
||||
for (size_t i = 0; i < indices_shapes_.size() - IntToSize(1); i++) {
|
||||
dim_of_indices *= indices_shapes_[i];
|
||||
}
|
||||
|
||||
size_t dim_after_indices = 1;
|
||||
size_t dim_indices_last = indices_shapes_[indices_shapes_.size() - IntToSize(1)];
|
||||
for (size_t i = dim_indices_last; i < input_shapes_.size(); i++) {
|
||||
dim_after_indices *= input_shapes_[i];
|
||||
}
|
||||
dims_.emplace_back(dim_of_indices);
|
||||
dims_.emplace_back(dim_after_indices);
|
||||
dims_.emplace_back(dim_indices_last);
|
||||
return;
|
||||
}
|
||||
size_t GetSize(const std::vector<size_t> &shape) const {
|
||||
if (shape.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
size_t result = sizeof(T);
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
result *= shape[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<size_t> input_shapes_;
|
||||
std::vector<size_t> indices_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
|
||||
std::vector<size_t> dims_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
std::vector<S> batch_strides_;
|
||||
std::vector<S> batch_indices_;
|
||||
|
||||
S *dev_batch_strides_;
|
||||
S *dev_batch_indices_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_GATHERND_GPU_KERNEL_H
|
@ -0,0 +1,33 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registers ScatterNdGpuFwdKernel<T, S> for every supported combination of update/output type T (input 1 and
// output) and index type S (input 0).

// int32 indices, float32 updates.
MS_REG_GPU_KERNEL_TWO(ScatterNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      ScatterNdGpuFwdKernel, float, int)

// int32 indices, float16 updates.
MS_REG_GPU_KERNEL_TWO(ScatterNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      ScatterNdGpuFwdKernel, half, int)

// int32 indices, int32 updates.
MS_REG_GPU_KERNEL_TWO(ScatterNd,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      ScatterNdGpuFwdKernel, int, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,175 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class ScatterNdGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
ScatterNdGpuFwdKernel()
|
||||
: input_size_(1),
|
||||
indices_size_(1),
|
||||
output_size_(1),
|
||||
block_size_(1),
|
||||
indices_stride_(nullptr),
|
||||
work_shape_(nullptr),
|
||||
indices_dim_0_(0),
|
||||
indices_dim_1_(0) {}
|
||||
~ScatterNdGpuFwdKernel() {
|
||||
if (indices_stride_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
|
||||
}
|
||||
if (work_shape_ != nullptr) {
|
||||
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(work_shape_));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
S *indices = GetDeviceAddress<S>(inputs, 0);
|
||||
T *update = GetDeviceAddress<T>(inputs, 1);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
ScatterNd(indices, update, output, block_size_, input_size_, output_size_, indices_dim_0_, indices_dim_1_,
|
||||
indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 2 input.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but transpose needs 1 output.";
|
||||
return false;
|
||||
}
|
||||
|
||||
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
indices_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
||||
vec_work_shape_ = GetAttr<std::vector<S>>(kernel_node, "shape");
|
||||
|
||||
GetSize();
|
||||
|
||||
size_t indices_len = sizeof(S) * vec_indices_stride_.size();
|
||||
void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
|
||||
if (indices_stride_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
|
||||
}
|
||||
indices_stride_ = static_cast<S *>(indices_stride_work);
|
||||
|
||||
size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
|
||||
void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
|
||||
if (work_shape_work == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
|
||||
}
|
||||
work_shape_ = static_cast<S *>(work_shape_work);
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(
|
||||
cudaMemcpy(indices_stride_, &vec_indices_stride_[0], indices_len, cudaMemcpyHostToDevice),
|
||||
"cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice),
|
||||
"cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(indices_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
return;
|
||||
}
|
||||
|
||||
void GetSize() {
|
||||
indices_size_ = sizeof(S);
|
||||
for (size_t i = 0; i < indices_shapes_.size(); i++) {
|
||||
indices_size_ *= indices_shapes_[i];
|
||||
}
|
||||
input_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < input_shapes_.size(); i++) {
|
||||
input_size_ *= input_shapes_[i];
|
||||
}
|
||||
output_size_ = sizeof(T);
|
||||
for (size_t i = 0; i < output_shapes_.size(); i++) {
|
||||
output_size_ *= output_shapes_[i];
|
||||
}
|
||||
|
||||
// calculate indices dim 0/1
|
||||
indices_dim_0_ = indices_shapes_[0];
|
||||
indices_dim_1_ = indices_shapes_[1];
|
||||
|
||||
// calculate block_size
|
||||
for (size_t i = indices_dim_1_; i < output_shapes_.size(); i++) {
|
||||
block_size_ *= output_shapes_[i];
|
||||
}
|
||||
|
||||
// calculate indices_stride
|
||||
for (size_t i = 0; i < indices_dim_1_; i++) {
|
||||
vec_indices_stride_.push_back(0);
|
||||
}
|
||||
|
||||
vec_indices_stride_[indices_dim_1_ - 1] = block_size_;
|
||||
|
||||
for (size_t i = indices_dim_1_ - 1; i > 0; --i) {
|
||||
vec_indices_stride_[i - 1] = vec_indices_stride_[i] * output_shapes_[i];
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shapes_;
|
||||
std::vector<size_t> indices_shapes_;
|
||||
std::vector<size_t> output_shapes_;
|
||||
std::vector<S> vec_indices_stride_;
|
||||
std::vector<S> vec_work_shape_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
size_t input_size_;
|
||||
size_t indices_size_;
|
||||
size_t output_size_;
|
||||
size_t block_size_;
|
||||
|
||||
S *indices_stride_;
|
||||
S *work_shape_;
|
||||
size_t indices_dim_0_;
|
||||
size_t indices_dim_1_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_SCATTER_ND_GPU_KERNEL_H
|
@ -0,0 +1,81 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void BoundingBoxDecodeKernel(const size_t size, const T *rois, const T *deltas, T *bboxes, const float m1,
|
||||
const float m2, const float m3, const float m4, const float s1, const float s2,
|
||||
const float s3, const float s4, const int max_height, const int max_width,
|
||||
const float ratio_clip) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
const size_t left_x = i * 4;
|
||||
const size_t left_y = i * 4 + 1;
|
||||
const size_t right_x = i * 4 + 2;
|
||||
const size_t right_y = i * 4 + 3;
|
||||
|
||||
T dx = deltas[left_x] * s1 + m1;
|
||||
T dy = deltas[left_y] * s2 + m2;
|
||||
T dw = deltas[right_x] * s3 + m3;
|
||||
T dh = deltas[right_y] * s4 + m4;
|
||||
|
||||
T max_ratio = abs(log(ratio_clip));
|
||||
|
||||
dw = dw > max_ratio ? max_ratio : (dw < (-max_ratio) ? (-max_ratio) : dw);
|
||||
dh = dh > max_ratio ? max_ratio : (dh < (-max_ratio) ? (-max_ratio) : dh);
|
||||
|
||||
T px = (rois[left_x] + rois[right_x]) * 0.5f;
|
||||
T py = (rois[left_y] + rois[right_y]) * 0.5f;
|
||||
T pw = rois[right_x] - rois[left_x] + 1.0f;
|
||||
T ph = rois[right_y] - rois[left_y] + 1.0f;
|
||||
|
||||
T gx = px + pw * dx;
|
||||
T gy = py + ph * dy;
|
||||
T gw = pw * exp(dw);
|
||||
T gh = ph * exp(dh);
|
||||
|
||||
T x1 = gx - gw * 0.5f + 0.5f;
|
||||
T y1 = gy - gh * 0.5f + 0.5f;
|
||||
T x2 = gx + gw * 0.5f - 0.5f;
|
||||
T y2 = gy + gh * 0.5f - 0.5f;
|
||||
|
||||
x1 = x1 > max_width ? max_width : (x1 < 0 ? 0 : x1);
|
||||
y1 = y1 > max_height ? max_height : (y1 < 0 ? 0 : y1);
|
||||
x2 = x2 > max_width ? max_width : (x2 < 0 ? 0 : x2);
|
||||
y2 = y2 > max_height ? max_height : (y2 < 0 ? 0 : y2);
|
||||
|
||||
bboxes[left_x] = x1;
|
||||
bboxes[left_y] = y1;
|
||||
bboxes[right_x] = x2;
|
||||
bboxes[right_y] = y2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2,
|
||||
const float &m3, const float &m4, const float &s1, const float &s2, const float &s3,
|
||||
const float &s4, const int &max_height, const int &max_width, const float &ratio_clip,
|
||||
cudaStream_t cuda_stream) {
|
||||
BoundingBoxDecodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, rois, deltas, bboxes, m1, m2, m3, m4,
|
||||
s1, s2, s3, s4, max_height, max_width,
|
||||
ratio_clip);
|
||||
}
|
||||
|
||||
template void BoundingBoxDecode<float>(const size_t size, const float *rois, const float *deltas, float *bboxes,
|
||||
const float &m1, const float &m2, const float &m3, const float &m4,
|
||||
const float &s1, const float &s2, const float &s3, const float &s4,
|
||||
const int &max_height, const int &max_width, const float &ratio_clip,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,27 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void BoundingBoxDecode(const size_t size, const T *rois, const T *deltas, T *bboxes, const float &m1, const float &m2,
|
||||
const float &m3, const float &m4, const float &s1, const float &s2, const float &s3,
|
||||
const float &s4, const int &max_height, const int &max_width, const float &ratio_clip,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_DECODE_IMPL_H_
|
@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void BoundingBoxEncodeKernel(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas,
|
||||
const float m1, const float m2, const float m3, const float m4, const float s1,
|
||||
const float s2, const float s3, const float s4) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
const size_t left_x = i * 4;
|
||||
const size_t left_y = i * 4 + 1;
|
||||
const size_t right_x = i * 4 + 2;
|
||||
const size_t right_y = i * 4 + 3;
|
||||
|
||||
T px = (anchor_box[left_x] + anchor_box[right_x]) * 0.5f;
|
||||
T py = (anchor_box[left_y] + anchor_box[right_y]) * 0.5f;
|
||||
T pw = anchor_box[right_x] - anchor_box[left_x] + 1.0f;
|
||||
T ph = anchor_box[right_y] - anchor_box[left_y] + 1.0f;
|
||||
|
||||
T gx = (groundtruth_box[left_x] + groundtruth_box[right_x]) * 0.5f;
|
||||
T gy = (groundtruth_box[left_y] + groundtruth_box[right_y]) * 0.5f;
|
||||
T gw = groundtruth_box[right_x] - groundtruth_box[left_x] + 1.0f;
|
||||
T gh = groundtruth_box[right_y] - groundtruth_box[left_y] + 1.0f;
|
||||
|
||||
T dx = (gx - px) / pw;
|
||||
T dy = (gy - py) / ph;
|
||||
T dw = log(gw / pw);
|
||||
T dh = log(gh / ph);
|
||||
|
||||
deltas[left_x] = (dx - m1) / s1;
|
||||
deltas[left_y] = (dy - m2) / s2;
|
||||
deltas[right_x] = (dw - m3) / s3;
|
||||
deltas[right_y] = (dh - m4) / s4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1,
|
||||
const float &m2, const float &m3, const float &m4, const float &s1, const float &s2,
|
||||
const float &s3, const float &s4, cudaStream_t cuda_stream) {
|
||||
BoundingBoxEncodeKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, anchor_box, groundtruth_box, deltas,
|
||||
m1, m2, m3, m4, s1, s2, s3, s4);
|
||||
}
|
||||
|
||||
template void BoundingBoxEncode<float>(const size_t size, const float *anchor_box, const float *groundtruth_box,
|
||||
float *deltas, const float &m1, const float &m2, const float &m3,
|
||||
const float &m4, const float &s1, const float &s2, const float &s3,
|
||||
const float &s4, cudaStream_t cuda_stream);
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T>
|
||||
void BoundingBoxEncode(const size_t size, const T *anchor_box, const T *groundtruth_box, T *deltas, const float &m1,
|
||||
const float &m2, const float &m3, const float &m4, const float &s1, const float &s2,
|
||||
const float &s3, const float &s4, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_BOUNDINGBOX_ENCODE_IMPL_H_
|
@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/gathernd.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename S>
|
||||
__global__ void GatherNdKernel(T *input, S *indices, T *output, const size_t output_dim0, const size_t output_dim1,
|
||||
const size_t indices_dim1, S *batch_indices, S *batch_strides) {
|
||||
int num = output_dim0 * output_dim1;
|
||||
int i, j;
|
||||
for (int write_index = blockIdx.x * blockDim.x + threadIdx.x; write_index < num;
|
||||
write_index += blockDim.x * gridDim.x) {
|
||||
i = write_index / output_dim1 % output_dim0;
|
||||
j = write_index % output_dim1;
|
||||
|
||||
bool out_of_bound = false;
|
||||
int read_index = 0;
|
||||
int indices_i = 0;
|
||||
for (size_t k = 0; k < indices_dim1; k++) {
|
||||
size_t ind = indices_dim1 * i + k;
|
||||
indices_i = indices[ind];
|
||||
out_of_bound |= !(indices_i < batch_indices[k]);
|
||||
read_index += indices_i * batch_strides[k];
|
||||
}
|
||||
read_index += j;
|
||||
|
||||
if (!out_of_bound) {
|
||||
output[write_index] = input[read_index];
|
||||
} else {
|
||||
output[write_index] = 0;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
template <typename T, typename S>
|
||||
void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1,
|
||||
const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream) {
|
||||
int size = output_dim0 * output_dim1;
|
||||
GatherNdKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(input, indices, output, output_dim0, output_dim1,
|
||||
indices_dim1, batch_indices, batch_strides);
|
||||
return;
|
||||
}
|
||||
|
||||
template void GatherNd<float, int>(float *input, int *indices, float *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
||||
int *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<half, int>(half *input, int *indices, half *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
||||
int *batch_strides, cudaStream_t stream);
|
||||
template void GatherNd<int, int>(int *input, int *indices, int *output, const size_t &output_dim0,
|
||||
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
|
||||
int *batch_strides, cudaStream_t stream);
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_GATHERND_GPU_CU_H
|
||||
#define MINDSPORE_GATHERND_GPU_CU_H
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void GatherNd(T *input, S *indices, T *output, const size_t &output_dim0, const size_t &output_dim1,
|
||||
const size_t &indices_dim1, S *batch_indices, S *batch_strides, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_GATHERND_GPU_CU_H
|
@ -0,0 +1,68 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
template <typename T, typename S>
|
||||
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
|
||||
const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
|
||||
S *indices_stride, S *work_shape) {
|
||||
int i, j;
|
||||
for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
|
||||
read_index += blockDim.x * gridDim.x) {
|
||||
int write_index = 0;
|
||||
bool out_bound = false;
|
||||
|
||||
i = read_index / block_size;
|
||||
j = read_index % block_size;
|
||||
|
||||
for (size_t k = 0; k < indices_dim_1; k++) {
|
||||
S indices_i = indices[i * indices_dim_1 + k];
|
||||
out_bound |= indices_i >= work_shape[k];
|
||||
write_index += indices_i * indices_stride[k];
|
||||
}
|
||||
|
||||
write_index += j;
|
||||
out_bound |= write_index >= output_size;
|
||||
|
||||
if (!out_bound) {
|
||||
output[write_index] = update[read_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
||||
S *work_shape, cudaStream_t stream) {
|
||||
ScatterNdKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
|
||||
output_size, indices_dim_0, indices_dim_1,
|
||||
indices_stride, work_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
// Explicit instantiations of ScatterNd: int indices with float, half and int values.

// float values
template void ScatterNd<float, int>(int *indices, float *update, float *output,
                                    const size_t &block_size, const size_t &input_size, const size_t &output_size,
                                    const size_t &indices_dim_0, const size_t &indices_dim_1,
                                    int *indices_stride, int *work_shape, cudaStream_t stream);

// half values
template void ScatterNd<half, int>(int *indices, half *update, half *output,
                                   const size_t &block_size, const size_t &input_size, const size_t &output_size,
                                   const size_t &indices_dim_0, const size_t &indices_dim_1,
                                   int *indices_stride, int *work_shape, cudaStream_t stream);

// int values
template void ScatterNd<int, int>(int *indices, int *update, int *output,
                                  const size_t &block_size, const size_t &input_size, const size_t &output_size,
                                  const size_t &indices_dim_0, const size_t &indices_dim_1,
                                  int *indices_stride, int *work_shape, cudaStream_t stream);
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_SCATTER_ND_GPU_CU_H
|
||||
#define MINDSPORE_SCATTER_ND_GPU_CU_H
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
|
||||
const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
|
||||
S *work_shape, cudaStream_t stream);
|
||||
#endif // MINDSPORE_SCATTER_ND_GPU_CU_H
|
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void SGDKernel(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *grad,
|
||||
const T *momentum, const T *lr, T *param, T *accum, T *stat) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
|
||||
T grad_new = grad[i];
|
||||
if (weight_decay != static_cast<T>(0)) {
|
||||
grad_new += param[i] * weight_decay;
|
||||
}
|
||||
|
||||
if (momentum[0] != static_cast<T>(0)) {
|
||||
if (stat[i] == static_cast<T>(0)) {
|
||||
accum[i] = grad_new;
|
||||
stat[i] = 0;
|
||||
} else {
|
||||
accum[i] = accum[i] * momentum[0] + (1.0 - dampening) * grad_new;
|
||||
}
|
||||
|
||||
if (nesterov) {
|
||||
grad_new += accum[i] * momentum[0];
|
||||
} else {
|
||||
grad_new = accum[i];
|
||||
}
|
||||
}
|
||||
|
||||
param[i] -= lr[0] * grad_new;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
|
||||
const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream) {
|
||||
SGDKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dampening, weight_decay, nesterov, grad, momentum,
|
||||
lr, param, accum, stat);
|
||||
}
|
||||
|
||||
template void SGD(const int size, const float dampening, const float weight_decay, const bool nesterov, const float *lr,
|
||||
const float *momentum, const float *grad, float *param, float *accum, float *stat,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void SGD(const int size, const T dampening, const T weight_decay, const bool nesterov, const T *lr, const T *momentum,
|
||||
const T *grad, T *param, T *accum, T *stat, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SGD_IMPL_H_
|
@ -0,0 +1,32 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/nn/sgd_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registers SGDGpuKernel<float> for the SGD op: six float32 inputs, one float32 output.
MS_REG_GPU_KERNEL_ONE(
  SGD,
  KernelAttr()
    .AddInputAttr(kNumberTypeFloat32)  // input 0
    .AddInputAttr(kNumberTypeFloat32)  // input 1
    .AddInputAttr(kNumberTypeFloat32)  // input 2
    .AddInputAttr(kNumberTypeFloat32)  // input 3
    .AddInputAttr(kNumberTypeFloat32)  // input 4
    .AddInputAttr(kNumberTypeFloat32)  // input 5
    .AddOutputAttr(kNumberTypeFloat32),
  SGDGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class SGDGpuKernel : public GpuKernel {
|
||||
public:
|
||||
SGDGpuKernel() : size_(1), dampening_(0.0), weight_decay_(0.0), nesterov_(false) {}
|
||||
~SGDGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream) override {
|
||||
T *param = GetDeviceAddress<T>(inputs, 0);
|
||||
T *grad = GetDeviceAddress<T>(inputs, 1);
|
||||
T *lr = GetDeviceAddress<T>(inputs, 2);
|
||||
T *accum = GetDeviceAddress<T>(inputs, 3);
|
||||
T *momentum = GetDeviceAddress<T>(inputs, 4);
|
||||
T *stat = GetDeviceAddress<T>(inputs, 5);
|
||||
|
||||
SGD(size_, dampening_, weight_decay_, nesterov_, lr, momentum, grad, param, accum, stat,
|
||||
reinterpret_cast<cudaStream_t>(stream));
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
dampening_ = GetAttr<float>(kernel_node, "dampening");
|
||||
weight_decay_ = GetAttr<float>(kernel_node, "weight_decay");
|
||||
nesterov_ = GetAttr<bool>(kernel_node, "nesterov");
|
||||
|
||||
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (auto &dim : input_shape) {
|
||||
size_ *= dim;
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
size_t input_size = size_ * sizeof(T);
|
||||
input_size_list_.push_back(input_size); // parameter
|
||||
input_size_list_.push_back(input_size); // gradient
|
||||
input_size_list_.push_back(sizeof(T)); // lr
|
||||
input_size_list_.push_back(input_size); // accum
|
||||
input_size_list_.push_back(sizeof(T)); // momentum
|
||||
input_size_list_.push_back(input_size); // stat
|
||||
output_size_list_.push_back(input_size);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
float dampening_;
|
||||
float weight_decay_;
|
||||
bool nesterov_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SGD_KERNEL_H_
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/other/boundingbox_decode_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registers BoundingBoxDecodeGpuKernel<float> for the BoundingBoxDecode op:
// two float32 inputs, one float32 output.
MS_REG_GPU_KERNEL_ONE(BoundingBoxDecode,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)  // input 0
                        .AddInputAttr(kNumberTypeFloat32)  // input 1
                        .AddOutputAttr(kNumberTypeFloat32),
                      BoundingBoxDecodeGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,152 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_decode_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class BoundingBoxDecodeGpuKernel : public GpuKernel {
|
||||
public:
|
||||
BoundingBoxDecodeGpuKernel() : rois_size_(0), deltas_size_(0), bboxes_size_(0), wh_ratio_clip_(0.016) {}
|
||||
|
||||
~BoundingBoxDecodeGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *rois_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *deltas_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *bboxes_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (inputs[0]->size != inputs[1]->size) {
|
||||
MS_LOG(ERROR) << "Rois box size must equal with deltas box size -" << inputs[1]->size << ", but got"
|
||||
<< inputs[0]->size;
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t coordinate = 4;
|
||||
const size_t block_size = inputs[0]->size / sizeof(T);
|
||||
if ((block_size % coordinate) != 0) {
|
||||
MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
|
||||
return false;
|
||||
}
|
||||
|
||||
BoundingBoxDecode(block_size / coordinate, rois_addr, deltas_addr, bboxes_addr, means_[0], means_[1], means_[2],
|
||||
means_[3], stds_[0], stds_[1], stds_[2], stds_[3], max_shape_[0], max_shape_[1], wh_ratio_clip_,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxDecode needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
rois_size_ = sizeof(T);
|
||||
deltas_size_ = sizeof(T);
|
||||
bboxes_size_ = sizeof(T);
|
||||
|
||||
auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < logits_shape.size(); i++) {
|
||||
rois_size_ *= logits_shape[i];
|
||||
}
|
||||
|
||||
auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (size_t i = 0; i < labels_shape.size(); i++) {
|
||||
deltas_size_ *= labels_shape[i];
|
||||
}
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
bboxes_size_ *= output_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
const size_t coordinate_size = 4;
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
||||
float mean = GetAttr<int>(kernel_node, "means");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
means_.emplace_back(mean);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
||||
float std = GetAttr<int>(kernel_node, "stds");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
stds_.emplace_back(std);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute stds type is invalid.";
|
||||
}
|
||||
|
||||
max_shape_ = GetAttr<std::vector<int>>(kernel_node, "max_shape");
|
||||
wh_ratio_clip_ = GetAttr<float>(kernel_node, "wh_ratio_clip");
|
||||
|
||||
if (means_.size() < coordinate_size || stds_.size() < coordinate_size) {
|
||||
MS_LOG(EXCEPTION) << "The size of means or stds is less than 4.";
|
||||
}
|
||||
|
||||
if (max_shape_.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "The size of max_shape is less than 2.";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(rois_size_);
|
||||
input_size_list_.push_back(deltas_size_);
|
||||
output_size_list_.push_back(bboxes_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t rois_size_;
|
||||
size_t deltas_size_;
|
||||
size_t bboxes_size_;
|
||||
std::vector<float> means_;
|
||||
std::vector<float> stds_;
|
||||
std::vector<int> max_shape_;
|
||||
float wh_ratio_clip_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_DECODE_GPU_KERNEL_H
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/other/boundingbox_encode_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Registers BoundingBoxEncodeGpuKernel<float> for the BoundingBoxEncode op:
// two float32 inputs, one float32 output.
MS_REG_GPU_KERNEL_ONE(BoundingBoxEncode,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)  // input 0
                        .AddInputAttr(kNumberTypeFloat32)  // input 1
                        .AddOutputAttr(kNumberTypeFloat32),
                      BoundingBoxEncodeGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,143 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/boundingbox_encode_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class BoundingBoxEncodeGpuKernel : public GpuKernel {
|
||||
public:
|
||||
BoundingBoxEncodeGpuKernel() : anchor_size_(0), groundtruth_size_(0), deltas_size_(0) {}
|
||||
|
||||
~BoundingBoxEncodeGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *anchor_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *groundtruth_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *deltas_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
if (inputs[0]->size != inputs[1]->size) {
|
||||
MS_LOG(ERROR) << "Anchor box size must equal with groundtruth box size -" << inputs[1]->size << ", but got"
|
||||
<< inputs[0]->size;
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t coordinate = 4;
|
||||
const size_t block_size = inputs[0]->size / sizeof(T);
|
||||
if ((block_size % coordinate) != 0) {
|
||||
MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
|
||||
return false;
|
||||
}
|
||||
|
||||
BoundingBoxEncode(block_size / coordinate, anchor_addr, groundtruth_addr, deltas_addr, means_[0], means_[1],
|
||||
means_[2], means_[3], stds_[0], stds_[1], stds_[2], stds_[3],
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but BoundingBoxEncode needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
anchor_size_ = sizeof(T);
|
||||
groundtruth_size_ = sizeof(T);
|
||||
deltas_size_ = sizeof(T);
|
||||
|
||||
auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < logits_shape.size(); i++) {
|
||||
anchor_size_ *= logits_shape[i];
|
||||
}
|
||||
|
||||
auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (size_t i = 0; i < labels_shape.size(); i++) {
|
||||
groundtruth_size_ *= labels_shape[i];
|
||||
}
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
deltas_size_ *= output_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
const size_t coordinate_size = 4;
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<ValueList>()) {
|
||||
means_ = GetAttr<std::vector<float>>(kernel_node, "means");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("means")->isa<FloatImm>()) {
|
||||
float mean = GetAttr<int>(kernel_node, "means");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
means_.emplace_back(mean);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute means type is invalid.";
|
||||
}
|
||||
|
||||
if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueTuple>() ||
|
||||
AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<ValueList>()) {
|
||||
stds_ = GetAttr<std::vector<float>>(kernel_node, "stds");
|
||||
} else if (AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stds")->isa<FloatImm>()) {
|
||||
float std = GetAttr<int>(kernel_node, "stds");
|
||||
for (size_t i = 0; i < coordinate_size; i++) {
|
||||
stds_.emplace_back(std);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute stds type is invalid.";
|
||||
}
|
||||
|
||||
if (means_.size() < coordinate_size || stds_.size() < coordinate_size) {
|
||||
MS_LOG(EXCEPTION) << "The size of means or stds is less than 4.";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(anchor_size_);
|
||||
input_size_list_.push_back(groundtruth_size_);
|
||||
output_size_list_.push_back(deltas_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t anchor_size_;
|
||||
size_t groundtruth_size_;
|
||||
size_t deltas_size_;
|
||||
std::vector<float> means_;
|
||||
std::vector<float> stds_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_BOUNDINGBOX_ENCODE_GPU_KERNEL_H
|
@ -0,0 +1,60 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetBoundingBoxDecode(nn.Cell):
    """Minimal cell wrapping a single ``P.BoundingBoxDecode`` op.

    The op is built with ``max_shape=(768, 1280)`` and ``wh_ratio_clip=0.016``;
    ``means`` and ``stds`` are forwarded from the constructor arguments.
    """

    def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.decode = P.BoundingBoxDecode(
            max_shape=(768, 1280),
            means=means,
            stds=stds,
            wh_ratio_clip=0.016,
        )

    def construct(self, anchor, groundtruth):
        """Apply the decode op to the two input tensors and return its result."""
        decoded = self.decode(anchor, groundtruth)
        return decoded
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_boundingbox_decode():
    """Check BoundingBoxDecode on GPU against fixed expected values.

    The same inputs are run in graph mode and then in pynative mode; every
    output element must be within 1e-4 of the expected value.
    """
    means = (0.1, 0.1, 0.2, 0.2)
    stds = (2.0, 2.0, 3.0, 3.0)

    anchor_box = Tensor(np.array([[4, 1, 2, 1], [2, 2, 2, 3]], np.float32), mindspore.float32)
    deltas_box = Tensor(np.array([[3, 1, 2, 2], [1, 2, 1, 4]], np.float32), mindspore.float32)

    expect_deltas = np.array([[28.6500, 0.0000, 0.0000, 33.8500],
                              [0.0000, 0.0000, 15.8663, 72.7000]], np.float32)
    tolerance = np.ones(shape=[2, 4]) * 1.0e-4

    # Graph mode first, then pynative mode, with a freshly built network each time.
    for run_mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=run_mode, device_target='GPU')
        net = NetBoundingBoxDecode(means, stds)
        result = net(anchor_box, deltas_box)
        deviation = result.asnumpy() - expect_deltas
        assert np.all(np.abs(deviation) < tolerance)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue