add cpu_ops_maximum()

pull/9357/head
dinglongwei 4 years ago
parent 51d885815a
commit 8cd631ca8b

@ -0,0 +1,210 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/maximum_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void MaximumCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input_y_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
TypeId input_x_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
TypeId input_y_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
size_t max_input_shape_size =
input_x_shape_.size() > input_y_shape_.size() ? input_x_shape_.size() : input_y_shape_.size();
for (size_t i = 0; i < output_shape_.size(); i++) {
output_num_ *= output_shape_[i];
}
if ((input_x_shape_.size() == 0 && input_y_shape_.size() != 0) ||
(input_x_shape_.size() != 0 && input_y_shape_.size() == 0)) {
InitInputTensorAndScalar(max_input_shape_size);
} else if (max_input_shape_size == output_shape_.size() && output_shape_.size() != 0) {
InitInputTensors(input_x_dtype, input_y_dtype);
} else {
MS_LOG(EXCEPTION) << "Only support input two tensors or one tensor and one scalar";
}
}
template <typename T>
void MaximumCPUKernel<T>::InitInputTensorAndScalar(size_t max_input_shape_size) {
if (max_input_shape_size != output_shape_.size()) {
MS_LOG(EXCEPTION) << "Output tensor size must be equal to the max shape size of inputs";
}
need_broadcast_ = false;
}
template <typename T>
void MaximumCPUKernel<T>::InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype) {
if (input_x_dtype == kNumberTypeBool && input_y_dtype == kNumberTypeBool) {
MS_LOG(EXCEPTION) << "Input tensor types cannot be both bool";
}
// Check if the shape needs to be broadcast
need_broadcast_ = IsBroadcast();
if (need_broadcast_) {
InitTensorBroadcastShape();
}
}
template <typename T>
bool MaximumCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
T *input_x_ = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y_ = reinterpret_cast<T *>(inputs[1]->addr);
T *output_ = reinterpret_cast<T *>(outputs[0]->addr);
BroadcastArith(input_x_, input_y_, output_);
return true;
}
template <typename T>
void MaximumCPUKernel<T>::BroadcastArith(const T *input_x, const T *input_y, T *output) {
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(output);
if (need_broadcast_) {
BroadcastArithKernel(broadcast_input_x_shape_[0], broadcast_input_x_shape_[1], broadcast_input_x_shape_[2],
broadcast_input_x_shape_[3], broadcast_input_x_shape_[4], broadcast_input_x_shape_[5],
broadcast_input_x_shape_[6], broadcast_input_y_shape_[0], broadcast_input_y_shape_[1],
broadcast_input_y_shape_[2], broadcast_input_y_shape_[3], broadcast_input_y_shape_[4],
broadcast_input_y_shape_[5], broadcast_input_y_shape_[6], broadcast_output_shape_[0],
broadcast_output_shape_[1], broadcast_output_shape_[2], broadcast_output_shape_[3],
broadcast_output_shape_[4], broadcast_output_shape_[5], broadcast_output_shape_[6], input_x,
input_y, output);
} else {
if (input_x_shape_.size() == 0 || input_y_shape_.size() == 0) {
BroadcastArithOneScalarOneTensor(input_x, input_y, output);
} else {
BroadcastArithTensors(input_x, input_y, output);
}
}
}
template <typename T>
bool MaximumCPUKernel<T>::IsBroadcast() {
if (input_x_shape_.size() != input_y_shape_.size()) {
return true;
}
for (size_t i = 0; i < input_x_shape_.size(); i++) {
if (input_x_shape_[i] != input_y_shape_[i]) {
return true;
}
}
return false;
}
template <typename T>
void MaximumCPUKernel<T>::InitTensorBroadcastShape() {
if (output_shape_.size() > max_dims) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
}
broadcast_input_x_shape_.resize(max_dims, 1);
broadcast_input_y_shape_.resize(max_dims, 1);
broadcast_output_shape_.resize(max_dims, 1);
for (size_t i = 0; i < output_shape_.size(); i++) {
broadcast_output_shape_[i] = output_shape_[i];
}
int input_x_dim_offset = output_shape_.size() - input_x_shape_.size();
for (size_t j = 0; j < input_x_shape_.size(); j++) {
broadcast_input_x_shape_[j + input_x_dim_offset] = input_x_shape_[j];
input_x_num_ *= input_x_shape_[j];
}
int input_y_dim_offset = output_shape_.size() - input_y_shape_.size();
for (size_t k = 0; k < input_y_shape_.size(); k++) {
if (need_broadcast_) {
broadcast_input_y_shape_[k + input_y_dim_offset] = input_y_shape_[k];
input_y_num_ *= input_y_shape_[k];
}
}
}
// Broadcast comparation
template <typename T>
size_t MaximumCPUKernel<T>::Index(const size_t &index, const size_t &dim) {
return dim == 1 ? 0 : index;
}
// Broadcast Arithmetic
template <typename T>
void MaximumCPUKernel<T>::BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *input_x, const T *input_y, T *output) {
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(output);
for (size_t pos = 0; pos < output_num_; pos++) {
size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
size_t k = pos / (d3 * d4 * d5 * d6) % d2;
size_t l = pos / (d4 * d5 * d6) % d3;
size_t m = pos / (d5 * d6) % d4;
size_t n = pos / d6 % d5;
size_t o = pos % d6;
size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
output[pos] = MaximumFunc(input_x[l_index], input_y[r_index]);
}
}
template <typename T>
void MaximumCPUKernel<T>::BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output) {
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(output);
if (input_x_shape_.size() == 0) {
for (size_t i = 0; i < output_num_; ++i) {
output[i] = MaximumFunc(input_x[0], input_y[i]);
}
} else {
for (size_t i = 0; i < output_num_; ++i) {
output[i] = MaximumFunc(input_x[i], input_y[0]);
}
}
}
template <typename T>
void MaximumCPUKernel<T>::BroadcastArithTensors(const T *input_x, const T *input_y, T *output) {
MS_EXCEPTION_IF_NULL(input_x);
MS_EXCEPTION_IF_NULL(input_y);
MS_EXCEPTION_IF_NULL(output);
for (size_t i = 0; i < output_num_; ++i) {
output[i] = MaximumFunc(input_x[i], input_y[i]);
}
}
} // namespace kernel
} // namespace mindspore

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAXIMUM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MAXIMUM_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MaximumCPUKernel : public CPUKernel {
public:
MaximumCPUKernel() = default;
~MaximumCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
bool IsBroadcast();
size_t Index(const size_t &index, const size_t &dim);
void InitTensorBroadcastShape();
void InitInputTensorAndScalar(size_t max_input_shape_size);
void InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype);
// Broadcast Arithmetic
void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, const size_t l4,
const size_t l5, const size_t l6, const size_t r0, const size_t r1, const size_t r2,
const size_t r3, const size_t r4, const size_t r5, const size_t r6, const size_t d0,
const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *input_x, const T *input_y, T *output);
T MaximumFunc(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
void BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output);
void BroadcastArithTensors(const T *input_x, const T *input_y, T *output);
void BroadcastArith(const T *input_x, const T *input_y, T *output);
private:
bool need_broadcast_{false};
size_t input_x_num_{1};
size_t input_y_num_{1};
size_t output_num_{1};
std::vector<size_t> input_x_shape_;
std::vector<size_t> input_y_shape_;
std::vector<size_t> output_shape_;
std::vector<size_t> broadcast_input_x_shape_;
std::vector<size_t> broadcast_input_y_shape_;
std::vector<size_t> broadcast_output_shape_;
const size_t max_dims{7};
};
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaximumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
MaximumCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MaximumCPUKernel, float);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MaximumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32),
MaximumCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MaximumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
MaximumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MaximumCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
MaximumCPUKernel, uint64_t);
MS_REG_CPU_KERNEL_T(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MaximumCPUKernel, double);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UPDATE_CACHE_CPU_KERNEL_H_

@ -1867,7 +1867,7 @@ class Maximum(_MathBinaryOp):
and the data type is the one with higher precision or higher digits among the two inputs.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32)

@ -0,0 +1,193 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import Cell
from mindspore.ops import operations as P
class ConstScalarAndTensorMaximum(Cell):
def __init__(self):
super(ConstScalarAndTensorMaximum, self).__init__()
self.max = P.Maximum()
self.x = 20
def construct(self, y):
return self.max(self.x, y)
class TwoTensorsMaximum(Cell):
def __init__(self):
super(TwoTensorsMaximum, self).__init__()
self.max = P.Maximum()
def construct(self, x, y):
return self.max(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_constScalar_tensor_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
expect = [[20, 20, 20], [100, 200, 300]]
error = np.ones(shape=[2, 3]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = ConstScalarAndTensorMaximum()
output = max_op(x)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_Not_Broadcast_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
y = Tensor(np.array([[1, 2, 3], [100, 100, 200]]).astype(np.int32))
expect = [[2, 3, 4], [100, 200, 300]]
error = np.ones(shape=[2, 3]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_Broadcast_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
y = Tensor(np.array([[100, 100, 200]]).astype(np.int32))
expect = [[100, 100, 200], [100, 200, 300]]
error = np.ones(shape=[2, 3]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_Broadcast_oneDimension_int():
x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32))
y = Tensor(np.array([[100]]).astype(np.int32))
expect = [[100, 100, 100], [100, 200, 300]]
error = np.ones(shape=[2, 3]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_all_oneDimension_int():
x = Tensor(np.array([[2]]).astype(np.int32))
y = Tensor(np.array([[100]]).astype(np.int32))
expect = [[100]]
error = np.ones(shape=[1, 1]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_Broadcast_bool():
x = Tensor(np.array([[2, 2]]).astype(np.int32))
y = Tensor(np.array([[True, False], [False, False]]).astype(np.bool_))
expect = [[2, 2], [2, 2]]
error = np.ones(shape=[2, 2]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_bool():
x = Tensor(np.array([[2, 2], [-1, 100]]).astype(np.int32))
y = Tensor(np.array([[True, False], [False, False]]).astype(np.bool_))
expect = [[2, 2], [0, 100]]
error = np.ones(shape=[2, 2]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_float32():
x = Tensor(np.array([[2.0, 2.0], [-1, 100]]).astype(np.float32))
y = Tensor(np.array([[1.0, 2.1], [-0.8, 100.5]]).astype(np.float32))
expect = [[2.0, 2.1], [-0.8, 100.5]]
error = np.ones(shape=[2, 2]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_maximum_two_tensors_notBroadcast_float64():
x = Tensor(np.array([[2.0, 2.0], [-1, 100]]).astype(np.float64))
y = Tensor(np.array([[1.0, 2.1], [-0.8, 100.5]]).astype(np.float64))
expect = [[2.0, 2.1], [-0.8, 100.5]]
error = np.ones(shape=[2, 2]) * 1.0e-5
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
max_op = TwoTensorsMaximum()
output = max_op(x, y)
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
Loading…
Cancel
Save