parent
b606b84e6c
commit
7829bab811
@ -0,0 +1,45 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh"
|
||||
|
||||
template <typename T, typename S>
|
||||
__global__ void CheckValidKernel(const size_t size, const T *box, const T *img_metas, S *valid) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
const size_t left_x = i * 4;
|
||||
const size_t left_y = i * 4 + 1;
|
||||
const size_t right_x = i * 4 + 2;
|
||||
const size_t right_y = i * 4 + 3;
|
||||
|
||||
S valid_flag = false;
|
||||
valid_flag |= !(box[left_x] >= 0.f);
|
||||
valid_flag |= !(box[left_y] >= 0.f);
|
||||
valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
|
||||
valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
|
||||
|
||||
valid[i] = !valid_flag;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, typename S>
|
||||
void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream) {
|
||||
CheckValidKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, box, img_metas, valid);
|
||||
}
|
||||
|
||||
template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
|
||||
cudaStream_t cuda_stream);
|
@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T, typename S>
|
||||
void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
|
@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__device__ T CoordinateMax(const T a, const T b) {
|
||||
return (a > b ? a : b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ T CoordinateMin(const T a, const T b) {
|
||||
return (a < b ? a : b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
|
||||
const size_t input_len_0) {
|
||||
T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
|
||||
T overlaps_coordinate[IOU_DIMENSION];
|
||||
const T epsilon = 1e-10;
|
||||
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||
for (size_t j = 0; j < IOU_DIMENSION; j++) {
|
||||
location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
|
||||
location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
|
||||
}
|
||||
|
||||
overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
|
||||
overlaps_coordinate[1] = CoordinateMax(location_coordinate[0][1], location_coordinate[1][1]);
|
||||
overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
|
||||
overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
|
||||
|
||||
T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
|
||||
T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
|
||||
T overlaps = overlaps_w * overlaps_h;
|
||||
|
||||
T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
|
||||
location_coordinate[0][1] + 1);
|
||||
if (mode == 0) {
|
||||
T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
|
||||
location_coordinate[1][1] + 1);
|
||||
iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
|
||||
} else {
|
||||
iou_results[i] = overlaps / (area1 + epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode,
|
||||
const size_t &input_len_0, cudaStream_t cuda_stream) {
|
||||
IOUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, box1, box2, iou_results, mode, input_len_0);
|
||||
}
|
||||
|
||||
template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
|
||||
const size_t &input_len_0, cudaStream_t cuda_stream);
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_
|
||||
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
#define IOU_LOCATION_NUM 2
|
||||
#define IOU_DIMENSION 4
|
||||
|
||||
template <typename T>
|
||||
void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode,
|
||||
const size_t &input_len_0, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IOU_IMPL_H_
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_TWO(
|
||||
CheckValid,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidGpuKernel, float, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,106 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T, typename S>
|
||||
class CheckValidGpuKernel : public GpuKernel {
|
||||
public:
|
||||
CheckValidGpuKernel() : anchor_boxes_size_(0), img_metas_size_(0), valid_size_(0) {}
|
||||
|
||||
~CheckValidGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *anchor_boxes_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *img_metas_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
S *valid_addr = GetDeviceAddress<S>(outputs, 0);
|
||||
|
||||
const size_t coordinate = 4;
|
||||
const size_t block_size = inputs[0]->size / sizeof(T);
|
||||
if ((block_size % coordinate) != 0) {
|
||||
MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t size = block_size / coordinate;
|
||||
CheckValid(size, anchor_boxes_addr, img_metas_addr, valid_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
anchor_boxes_size_ = sizeof(T);
|
||||
img_metas_size_ = sizeof(T);
|
||||
valid_size_ = sizeof(S);
|
||||
|
||||
auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < anchor_boxes_shape.size(); i++) {
|
||||
anchor_boxes_size_ *= anchor_boxes_shape[i];
|
||||
}
|
||||
|
||||
auto img_metas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (size_t i = 0; i < img_metas_shape.size(); i++) {
|
||||
img_metas_size_ *= img_metas_shape[i];
|
||||
}
|
||||
|
||||
auto valid_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < valid_shape.size(); i++) {
|
||||
valid_size_ *= valid_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(anchor_boxes_size_);
|
||||
input_size_list_.push_back(img_metas_size_);
|
||||
output_size_list_.push_back(valid_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t anchor_boxes_size_;
|
||||
size_t img_metas_size_;
|
||||
size_t valid_size_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
|
@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/other/iou_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
IOUGpuKernel, float)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class IOUGpuKernel : public GpuKernel {
|
||||
public:
|
||||
IOUGpuKernel() : gt_boxes_size_(0), anchor_boxes_size_(0), iou_size_(0), mode_(0) {}
|
||||
|
||||
~IOUGpuKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
VARIABLE_NOT_USED(workspace);
|
||||
T *gt_boxes_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *anchor_boxes_addr = GetDeviceAddress<T>(inputs, 1);
|
||||
T *iou_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
const size_t coordinate = 4;
|
||||
const size_t block_size_0 = inputs[0]->size / sizeof(T);
|
||||
const size_t block_size_1 = inputs[1]->size / sizeof(T);
|
||||
if ((block_size_0 % coordinate) != 0 || (block_size_1 % coordinate) != 0) {
|
||||
MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t input_len_0 = block_size_0 / coordinate;
|
||||
const size_t input_len_1 = block_size_1 / coordinate;
|
||||
IOU(input_len_0 * input_len_1, gt_boxes_addr, anchor_boxes_addr, iou_addr, mode_, input_len_0,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
gt_boxes_size_ = sizeof(T);
|
||||
anchor_boxes_size_ = sizeof(T);
|
||||
iou_size_ = sizeof(T);
|
||||
|
||||
auto gt_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < gt_boxes_shape.size(); i++) {
|
||||
gt_boxes_size_ *= gt_boxes_shape[i];
|
||||
}
|
||||
|
||||
auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
for (size_t i = 0; i < anchor_boxes_shape.size(); i++) {
|
||||
anchor_boxes_size_ *= anchor_boxes_shape[i];
|
||||
}
|
||||
|
||||
auto iou_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
for (size_t i = 0; i < iou_shape.size(); i++) {
|
||||
iou_size_ *= iou_shape[i];
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
|
||||
std::string mode = GetAttr<std::string>(kernel_node, "mode");
|
||||
|
||||
if (mode == "iou") {
|
||||
mode_ = 0;
|
||||
} else if (mode == "iof") {
|
||||
mode_ = 1;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Mode only support 'iou' or 'iof'.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(gt_boxes_size_);
|
||||
input_size_list_.push_back(anchor_boxes_size_);
|
||||
output_size_list_.push_back(iou_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t gt_boxes_size_;
|
||||
size_t anchor_boxes_size_;
|
||||
size_t iou_size_;
|
||||
size_t mode_;
|
||||
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
|
@ -0,0 +1,54 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetCheckValid(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetCheckValid, self).__init__()
|
||||
self.valid = P.CheckValid()
|
||||
|
||||
def construct(self, anchor, image_metas):
|
||||
return self.valid(anchor, image_metas)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_boundingbox_decode():
|
||||
anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32)
|
||||
image_metas = np.array([768, 1280, 1], np.float32)
|
||||
anchor_box = Tensor(anchor, mindspore.float32)
|
||||
image_metas_box = Tensor(image_metas, mindspore.float32)
|
||||
expect = np.array([True, False, False], np.bool_)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
boundingbox_decode = NetCheckValid()
|
||||
output = boundingbox_decode(anchor_box, image_metas_box)
|
||||
diff = (output.asnumpy() == expect)
|
||||
assert (diff == 1).all()
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
boundingbox_decode = NetCheckValid()
|
||||
output = boundingbox_decode(anchor_box, image_metas_box)
|
||||
diff = (output.asnumpy() == expect)
|
||||
assert (diff == 1).all()
|
@ -0,0 +1,57 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class NetIOU(nn.Cell):
|
||||
def __init__(self, mode):
|
||||
super(NetIOU, self).__init__()
|
||||
self.encode = P.IOU(mode=mode)
|
||||
|
||||
def construct(self, anchor, groundtruth):
|
||||
return self.encode(anchor, groundtruth)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_iou():
|
||||
pos1 = [101, 169, 246, 429]
|
||||
pos2 = [121, 138, 304, 374]
|
||||
mode = "iou"
|
||||
pos1_box = Tensor(np.array(pos1).reshape(1, 4), mindspore.float32)
|
||||
pos2_box = Tensor(np.array(pos2).reshape(1, 4), mindspore.float32)
|
||||
expect_result = np.array(0.46551168, np.float32)
|
||||
|
||||
error = np.ones(shape=[1]) * 1.0e-6
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
overlaps = NetIOU(mode)
|
||||
output = overlaps(pos1_box, pos2_box)
|
||||
diff = output.asnumpy() - expect_result
|
||||
assert np.all(abs(diff) < error)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
overlaps = NetIOU(mode)
|
||||
output = overlaps(pos1_box, pos2_box)
|
||||
diff = output.asnumpy() - expect_result
|
||||
assert np.all(abs(diff) < error)
|
Loading…
Reference in new issue