diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.cc new file mode 100644 index 0000000000..5471109341 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.cc @@ -0,0 +1,105 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +size_t get_element_num(const std::vector &shape) { + size_t size = 1; + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + return size; +} + +template +bool check_validation(const std::vector &shape, const size_t num_before_axis, const size_t num_after_axis, + const std::vector &inputs, const std::vector &outputs) { + if (inputs.size() != 1 || outputs.size() != 2) { + MS_LOG(EXCEPTION) << "Wrong number of inputs or outputs!"; + return false; + } + size_t data_size = sizeof(T); + size_t input_size = get_element_num(shape) * data_size; + size_t output_num = num_before_axis * num_after_axis; + size_t out0_size = output_num * sizeof(int); + size_t out1_size = output_num * data_size; + if (inputs[0]->size != input_size || outputs[0]->size != out0_size || outputs[1]->size != out1_size) { + MS_LOG(EXCEPTION) << "invalid input or output data size!"; + return false; + } + return true; +} +} // namespace + +template +void ArgMinWithValueCPUKernel::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + size_t shape_len = shape_.size(); + int64_t axis = AnfAlgo::GetNodeAttr(kernel_node, AXIS); + axis += shape_len; + if (axis < 0) { + MS_LOG(EXCEPTION) << "Invalid axis:" << axis << ", should in range [-1, " << shape_len - 1 << "]"; + } + axis = axis % static_cast(shape_len); + num_before_axis_ = 1; + num_after_axis_ = 1; + for (size_t i = 0; i < shape_len; i++) { + if (static_cast(i) < axis) { + num_before_axis_ *= shape_[i]; + } else if (static_cast(i) > axis) { + num_after_axis_ *= shape_[i]; + } + } + dim_axis_ = shape_[axis]; +} + +template +bool ArgMinWithValueCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspaces*/, + const std::vector &outputs) { + if (!check_validation(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) { + return false; + } + + auto input = reinterpret_cast(inputs[0]->addr); + auto output0 = reinterpret_cast(outputs[0]->addr); + auto output1 = reinterpret_cast(outputs[1]->addr); + + for (size_t i = 0; i < num_before_axis_; i++) { + size_t src_index_i = i * dim_axis_ * num_after_axis_; + for (size_t j = 0; j < num_after_axis_; j++) { + std::vector array_axis; + size_t src_index_j = src_index_i + j; + for (size_t k = 0; k < dim_axis_; k++) { + size_t src_index_k = k * num_after_axis_ + src_index_j; + array_axis.push_back(static_cast(input[src_index_k])); + } + auto min_ops = std::min_element(array_axis.begin(), array_axis.end()); + auto min_index = static_cast(std::distance(array_axis.begin(), min_ops)); + auto dst_index = i * num_after_axis_ + j; + output0[dst_index] = min_index; + auto src_index = IntToSize(min_index) * num_after_axis_ + src_index_j; + output1[dst_index] = input[src_index]; + } + } + return true; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.h new file mode 100644 index 0000000000..cc3141cc61 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/argmin_with_value_cpu_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMINWITHVALUE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMINWITHVALUE_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class ArgMinWithValueCPUKernel : public CPUKernel { + public: + ArgMinWithValueCPUKernel() = default; + ~ArgMinWithValueCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + std::vector shape_; + size_t num_before_axis_; + size_t num_after_axis_; + size_t dim_axis_; +}; + +MS_REG_CPU_KERNEL_T( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + ArgMinWithValueCPUKernel, float); +MS_REG_CPU_KERNEL_T( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + ArgMinWithValueCPUKernel, float16); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARGMINWITHVALUE_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc index c18fa2fdd6..038a4c861f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc @@ -79,6 +79,46 @@ void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size } } +template +void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto dividend = input1[idx[0]]; + auto divisor = input2[idx[1]]; + if (divisor == 0) { + if (dividend == 0) { + out[i] = std::numeric_limits::quiet_NaN(); + continue; + } + if (std::numeric_limits::has_infinity) { + out[i] = dividend > 0 ? std::numeric_limits::infinity() : -std::numeric_limits::infinity(); + } else { + out[i] = dividend > 0 ? std::numeric_limits::max() : std::numeric_limits::min(); + } + continue; + } + out[i] = dividend / divisor; + } +} + +template +void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + std::vector idx; + GenIndex(i, &idx); + auto x = static_cast(input1[idx[0]]); + auto y = static_cast(input2[idx[1]]); + auto data_div = x / y; + auto data_div_min = data_div < 0.0 ? data_div : 0.0; + auto data_div_max = data_div > 0.0 ? data_div : 0.0; + auto data_div_max_floor = floor(data_div_max); + auto data_div_min_ceil = ceil(data_div_min); + auto data_div_res = data_div_max_floor + data_div_min_ceil; + out[i] = static_cast(x - data_div_res * y); + } +} + template void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { @@ -128,6 +168,10 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = MUL; } else if (kernel_name == prim::kPrimRealDiv->name()) { operate_type_ = REALDIV; + } else if (kernel_name == prim::kPrimDiv->name()) { + operate_type_ = DIV; + } else if (kernel_name == prim::kPrimMod->name()) { + operate_type_ = MOD; } else if (kernel_name == prim::kPrimPow->name()) { operate_type_ = POW; } else if (kernel_name == prim::kPrimLess->name()) { @@ -291,6 +335,10 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector &inputs, co threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul, this, input1, input2, output, start, end)); } else if (operate_type_ == REALDIV) { threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv, this, input1, input2, output, start, end)); + } else if (operate_type_ == DIV) { + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div, this, input1, input2, output, start, end)); + } else if (operate_type_ == MOD) { + threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod, this, input1, input2, output, start, end)); } else if (operate_type_ == POW) { threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow, this, input1, input2, output, start, end)); } else if (operate_type_ == ASSIGNADD) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index 51b2a4044e..b8990a7568 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -48,6 +48,10 @@ class ArithmeticCPUKernel : public CPUKernel { template void RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); template + void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); + template + void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end); + template void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end); template void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end); @@ -96,6 +100,24 @@ MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL( RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + ArithmeticCPUKernel); MS_REG_CPU_KERNEL( Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), ArithmeticCPUKernel); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc index a4323eef04..049c7d823b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc @@ -62,6 +62,13 @@ void ZerosLike(const T *in, T *out, size_t start, size_t end) { out[i] = static_cast(0); } } + +template +void Floor(const T *in, T *out, size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = static_cast(floor(in[i])); + } +} } // namespace void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { @@ -77,6 +84,8 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { operate_type_ = NEG; } else if (kernel_name == prim::kPrimSign->name()) { operate_type_ = SIGN; + } else if (kernel_name == prim::kPrimFloor->name()) { + operate_type_ = FLOOR; } dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } @@ -128,6 +137,8 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector &inputs threads.emplace_back(std::thread(ZerosLike, input, output, start, end)); } else if (operate_type_ == SIGN) { threads.emplace_back(std::thread(Sign, input, output, start, end)); + } else if (operate_type_ == FLOOR) { + threads.emplace_back(std::thread(Floor, input, output, start, end)); } start += once_compute_size; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h index 0b40dc68a4..9b4b3f36c1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h @@ -58,6 +58,8 @@ MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA ArithmeticSelfCPUKernel); MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), ArithmeticSelfCPUKernel); +MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArithmeticSelfCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h index f6a449377a..63cb502993 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h @@ -63,6 +63,7 @@ enum OperateType { SQRT, POW, REALDIV, + MOD, NEG, LESS, ASSIGNADD, @@ -77,6 +78,7 @@ enum OperateType { SIGN, EQUAL, NOTEQUAL, + FLOOR, }; class CPUKernel : public kernel::KernelMod { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.cc new file mode 100644 index 0000000000..9dc0c46db0 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.cc @@ -0,0 +1,221 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/minimum_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" + +namespace mindspore { +namespace kernel { + +template +void MinimumCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + input_x_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + input_y_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + TypeId input_x_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); + TypeId input_y_dtype = AnfAlgo::GetInputDeviceDataType(kernel_node, 1); + size_t max_input_shape_size = + input_x_shape_.size() > input_y_shape_.size() ? input_x_shape_.size() : input_y_shape_.size(); + for (size_t i = 0; i < output_shape_.size(); i++) { + output_num_ *= output_shape_[i]; + } + if ((input_x_shape_.size() == 0 && input_y_shape_.size() != 0) || + (input_x_shape_.size() != 0 && input_y_shape_.size() == 0)) { + InitInputTensorAndScalar(max_input_shape_size); + } else if (max_input_shape_size == output_shape_.size() && output_shape_.size() != 0) { + InitInputTensors(input_x_dtype, input_y_dtype); + } else { + MS_LOG(EXCEPTION) << "Only support input two tensors or one tensor and one scalar"; + } +} + +template +void MinimumCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but MinimumCPUKernel needs 2 input."; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but MinimumCPUKernel needs 1 output."; + } +} + +template +void MinimumCPUKernel::InitInputTensorAndScalar(size_t max_input_shape_size) { + if (max_input_shape_size != output_shape_.size()) { + MS_LOG(EXCEPTION) << "Output tensor size must be equal to the max shape size of inputs"; + } + need_broadcast_ = false; +} + +template +void MinimumCPUKernel::InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype) { + if (input_x_dtype == kNumberTypeBool && input_y_dtype == kNumberTypeBool) { + MS_LOG(EXCEPTION) << "Input tensor types cannot be both bool"; + } + // Check if the shape needs to be broadcast + need_broadcast_ = IsBroadcast(); + if (need_broadcast_) { + InitTensorBroadcastShape(); + } +} + +template +bool MinimumCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + T *input_x_ = reinterpret_cast(inputs[0]->addr); + T *input_y_ = reinterpret_cast(inputs[1]->addr); + T *output_ = reinterpret_cast(outputs[0]->addr); + BroadcastArith(input_x_, input_y_, output_); + return true; +} + +template +void MinimumCPUKernel::BroadcastArith(const T *input_x, const T *input_y, T *output) { + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(output); + if (need_broadcast_) { + BroadcastArithKernel(broadcast_input_x_shape_[0], broadcast_input_x_shape_[1], broadcast_input_x_shape_[2], + broadcast_input_x_shape_[3], broadcast_input_x_shape_[4], broadcast_input_x_shape_[5], + broadcast_input_x_shape_[6], broadcast_input_y_shape_[0], broadcast_input_y_shape_[1], + broadcast_input_y_shape_[2], broadcast_input_y_shape_[3], broadcast_input_y_shape_[4], + broadcast_input_y_shape_[5], broadcast_input_y_shape_[6], broadcast_output_shape_[0], + broadcast_output_shape_[1], broadcast_output_shape_[2], broadcast_output_shape_[3], + broadcast_output_shape_[4], broadcast_output_shape_[5], broadcast_output_shape_[6], input_x, + input_y, output); + } else { + if (input_x_shape_.size() == 0 || input_y_shape_.size() == 0) { + BroadcastArithOneScalarOneTensor(input_x, input_y, output); + } else { + BroadcastArithTensors(input_x, input_y, output); + } + } +} + +template +bool MinimumCPUKernel::IsBroadcast() { + if (input_x_shape_.size() != input_y_shape_.size()) { + return true; + } + for (size_t i = 0; i < input_x_shape_.size(); i++) { + if (input_x_shape_[i] != input_y_shape_[i]) { + return true; + } + } + return false; +} + +template +void MinimumCPUKernel::InitTensorBroadcastShape() { + if (output_shape_.size() > max_dims) { + MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7"; + } + broadcast_input_x_shape_.resize(max_dims, 1); + broadcast_input_y_shape_.resize(max_dims, 1); + broadcast_output_shape_.resize(max_dims, 1); + for (size_t i = 0; i < output_shape_.size(); i++) { + broadcast_output_shape_[i] = output_shape_[i]; + } + int input_x_dim_offset = output_shape_.size() - input_x_shape_.size(); + for (size_t j = 0; j < input_x_shape_.size(); j++) { + broadcast_input_x_shape_[j + input_x_dim_offset] = input_x_shape_[j]; + input_x_num_ *= input_x_shape_[j]; + } + int input_y_dim_offset = output_shape_.size() - input_y_shape_.size(); + for (size_t k = 0; k < input_y_shape_.size(); k++) { + if (need_broadcast_) { + broadcast_input_y_shape_[k + input_y_dim_offset] = input_y_shape_[k]; + input_y_num_ *= input_y_shape_[k]; + } + } +} + +// Broadcast comparation +template +size_t MinimumCPUKernel::Index(const size_t &index, const size_t &dim) { + return dim == 1 ? 0 : index; +} + +// Broadcast Arithmetic +template +void MinimumCPUKernel::BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, + const size_t l4, const size_t l5, const size_t l6, const size_t r0, + const size_t r1, const size_t r2, const size_t r3, const size_t r4, + const size_t r5, const size_t r6, const size_t d0, const size_t d1, + const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *input_x, const T *input_y, T *output) { + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(output); + for (size_t pos = 0; pos < output_num_; pos++) { + size_t i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0; + size_t j = pos / (d2 * d3 * d4 * d5 * d6) % d1; + size_t k = pos / (d3 * d4 * d5 * d6) % d2; + size_t l = pos / (d4 * d5 * d6) % d3; + size_t m = pos / (d5 * d6) % d4; + size_t n = pos / d6 % d5; + size_t o = pos % d6; + + size_t l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6; + l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6; + l_index += Index(k, l2) * l3 * l4 * l5 * l6; + l_index += Index(l, l3) * l4 * l5 * l6; + l_index += Index(m, l4) * l5 * l6; + l_index += Index(n, l5) * l6; + l_index += Index(o, l6); + size_t r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6; + r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6; + r_index += Index(k, r2) * r3 * r4 * r5 * r6; + r_index += Index(l, r3) * r4 * r5 * r6; + r_index += Index(m, r4) * r5 * r6; + r_index += Index(n, r5) * r6; + r_index += Index(o, r6); + output[pos] = MinimumFunc(input_x[l_index], input_y[r_index]); + } +} + +template +void MinimumCPUKernel::BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output) { + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(output); + if (input_x_shape_.size() == 0) { + for (size_t i = 0; i < output_num_; ++i) { + output[i] = MinimumFunc(input_x[0], input_y[i]); + } + } else { + for (size_t i = 0; i < output_num_; ++i) { + output[i] = MinimumFunc(input_x[i], input_y[0]); + } + } +} + +template +void MinimumCPUKernel::BroadcastArithTensors(const T *input_x, const T *input_y, T *output) { + MS_EXCEPTION_IF_NULL(input_x); + MS_EXCEPTION_IF_NULL(input_y); + MS_EXCEPTION_IF_NULL(output); + for (size_t i = 0; i < output_num_; ++i) { + output[i] = MinimumFunc(input_x[i], input_y[i]); + } +} + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.h new file mode 100644 index 0000000000..9c5d0aa3e9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/minimum_cpu_kernel.h @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MINIMUM_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MINIMUM_CPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +template +class MinimumCPUKernel : public CPUKernel { + public: + MinimumCPUKernel() = default; + ~MinimumCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + private: + void CheckParam(const CNodePtr &kernel_node); + + bool IsBroadcast(); + + size_t Index(const size_t &index, const size_t &dim); + + void InitTensorBroadcastShape(); + + void InitInputTensorAndScalar(size_t max_input_shape_size); + + void InitInputTensors(TypeId input_x_dtype, TypeId input_y_dtype); + + // Broadcast Arithmetic + void BroadcastArithKernel(const size_t l0, const size_t l1, const size_t l2, const size_t l3, const size_t l4, + const size_t l5, const size_t l6, const size_t r0, const size_t r1, const size_t r2, + const size_t r3, const size_t r4, const size_t r5, const size_t r6, const size_t d0, + const size_t d1, const size_t d2, const size_t d3, const size_t d4, const size_t d5, + const size_t d6, const T *input_x, const T *input_y, T *output); + + T MinimumFunc(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; } + + void BroadcastArithOneScalarOneTensor(const T *input_x, const T *input_y, T *output); + + void BroadcastArithTensors(const T *input_x, const T *input_y, T *output); + + void BroadcastArith(const T *input_x, const T *input_y, T *output); + + private: + bool need_broadcast_{false}; + size_t input_x_num_{1}; + size_t input_y_num_{1}; + size_t output_num_{1}; + std::vector input_x_shape_; + std::vector input_y_shape_; + std::vector output_shape_; + std::vector broadcast_input_x_shape_; + std::vector broadcast_input_y_shape_; + std::vector broadcast_output_shape_; + const size_t max_dims{7}; +}; + +MS_REG_CPU_KERNEL_T( + Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + MinimumCPUKernel, int32_t); + +MS_REG_CPU_KERNEL_T( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), + MinimumCPUKernel, uint32_t); + +MS_REG_CPU_KERNEL_T( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + MinimumCPUKernel, float); + +MS_REG_CPU_KERNEL_T( + Minimum, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + MinimumCPUKernel, int64_t); + +MS_REG_CPU_KERNEL_T( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), + MinimumCPUKernel, uint64_t); + +MS_REG_CPU_KERNEL_T( + Minimum, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + MinimumCPUKernel, double); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UPDATE_CACHE_CPU_KERNEL_H_ diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 8dc9a318c2..f635a70731 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -243,6 +243,8 @@ inline const PrimitivePtr kPrimNeg = std::make_shared("Neg"); inline const PrimitivePtr kPrimSub = std::make_shared("Sub"); inline const PrimitivePtr kPrimMul = std::make_shared("Mul"); inline const PrimitivePtr kPrimDiv = std::make_shared("Div"); +inline const PrimitivePtr kPrimMod = std::make_shared("Mod"); +inline const PrimitivePtr kPrimFloor = std::make_shared("Floor"); inline const PrimitivePtr kPrimDivNoNan = std::make_shared("DivNoNan"); inline const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); inline const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index a5a2458a65..8b3cb73d77 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1708,7 +1708,7 @@ class ArgMinWithValue(PrimitiveWithInfer): - output_x (Tensor) - The minimum value of input tensor, with the same shape as index. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> input_x = Tensor(np.array([0.0, 0.4, 0.6, 0.7, 0.1]), mindspore.float32) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 03cccdcbfb..8cf7fa37c1 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1833,7 +1833,7 @@ class Minimum(_MathBinaryOp): and the data type is the one with higher precision or higher digits among the two inputs. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1.0, 5.0, 3.0]), mindspore.float32) @@ -1963,7 +1963,7 @@ class Div(_MathBinaryOp): and the data type is the one with higher precision or higher digits among the two inputs. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([-4.0, 5.0, 6.0]), mindspore.float32) @@ -2158,7 +2158,7 @@ class Mod(_MathBinaryOp): ValueError: When `input_x` and `input_y` are not the same dtype. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> input_x = Tensor(np.array([-4.0, 5.0, 6.0]), mindspore.float32) @@ -2188,7 +2188,7 @@ class Floor(PrimitiveWithInfer): Tensor, has the same shape as `input_x`. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1.1, 2.5, -1.5]), mindspore.float32) diff --git a/tests/st/ops/cpu/test_argminwithvalue_op.py b/tests/st/ops/cpu/test_argminwithvalue_op.py new file mode 100644 index 0000000000..b6e387d35a --- /dev/null +++ b/tests/st/ops/cpu/test_argminwithvalue_op.py @@ -0,0 +1,139 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetArgminWithValue(nn.Cell): + def __init__(self, axis=0, keep_dims=False): + super(NetArgminWithValue, self).__init__() + self.argmin = P.ArgMinWithValue(axis=axis, keep_dims=keep_dims) + + def construct(self, x): + return self.argmin(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_argminwithvalue_fp32(): + x = np.array([[1., 20., 5.], + [67., 8., 9.], + [130., 24., 15.], + [-0.5, 25, 100]]).astype(np.float32) + argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.array([3, 1, 0]).astype(np.int32) + expect1 = np.array([-0.5, 8., 5.]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) + + output0, output1 = argmin_a0k(Tensor(x)) + expect0 = np.array([[3, 1, 0]]).astype(np.int32) + expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) + + output0, output1 = argmin_a1(Tensor(x)) + expect0 = np.array([0, 1, 2, 0]).astype(np.int32) + expect1 = np.array([1., 8., 15., -0.5]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) + + output0, output1 = argmin_a1k(Tensor(x)) + expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) + expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float32) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_argminwithvalue_fp16(): + x = np.array([[1., 20., 5.], + [67., 8., 9.], + [130., 24., 15.], + [-0.5, 25, 100]]).astype(np.float16) + argmin_a0 = NetArgminWithValue(axis=0, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.array([3, 1, 0]).astype(np.int32) + expect1 = np.array([-0.5, 8., 5.]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a0k = NetArgminWithValue(axis=0, keep_dims=True) + + output0, output1 = argmin_a0k(Tensor(x)) + expect0 = np.array([[3, 1, 0]]).astype(np.int32) + expect1 = np.array([[-0.5, 8., 5.]]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1 = NetArgminWithValue(axis=1, keep_dims=False) + + output0, output1 = argmin_a1(Tensor(x)) + expect0 = np.array([0, 1, 2, 0]).astype(np.int32) + expect1 = np.array([1., 8., 15., -0.5]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + argmin_a1k = NetArgminWithValue(axis=-1, keep_dims=True) + + output0, output1 = argmin_a1k(Tensor(x)) + expect0 = np.array([[0], [1], [2], [0]]).astype(np.int32) + expect1 = np.array([[1.], [8.], [15.], [-0.5]]).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_argminwithvalue_tensor(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5, 6).astype(np.float16) * prop + argmin_a0 = NetArgminWithValue(axis=-2, keep_dims=False) + + output0, output1 = argmin_a0(Tensor(x)) + expect0 = np.argmin(x, axis=-2) + expect1 = np.min(x, axis=-2).astype(np.float16) + error = np.ones(shape=expect1.shape) * 1.0e-6 + assert np.all(output0.asnumpy() == expect0) + assert np.all(np.abs(output1.asnumpy() - expect1) < error) diff --git a/tests/st/ops/cpu/test_arithmetic_op.py b/tests/st/ops/cpu/test_arithmetic_op.py index 069a40653c..516f7c2073 100644 --- a/tests/st/ops/cpu/test_arithmetic_op.py +++ b/tests/st/ops/cpu/test_arithmetic_op.py @@ -33,6 +33,24 @@ class SubNet(nn.Cell): return self.sub(x, y) +class DivNet(nn.Cell): + def __init__(self): + super(DivNet, self).__init__() + self.div = P.Div() + + def construct(self, x, y): + return self.div(x, y) + + +class ModNet(nn.Cell): + def __init__(self): + super(ModNet, self).__init__() + self.mod = P.Mod() + + def construct(self, x, y): + return self.mod(x, y) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -43,4 +61,194 @@ def test_sub(): output = net(Tensor(x), Tensor(y, mindspore.float32)) expect_output = x - y assert np.all(output.asnumpy() == expect_output) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_div(): + prop = 1 if np.random.random() < 0.5 else -1 + x0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + y0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + x1_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + y1_np = np.random.randint(1, 100, (2, 1, 4, 4)).astype(np.float32) * prop + x2_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.float16) * prop + y2_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float16) * prop + x3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop + y3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop + x4_np = np.array(768).astype(np.float32) * prop + y4_np = np.array(3072.5).astype(np.float32) * prop + x5_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int32) * prop + y5_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop + x6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop + y6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + x7_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int64) * prop + y7_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int64) * prop + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + x5 = Tensor(x5_np) + y5 = Tensor(y5_np) + x6 = Tensor(x6_np) + y6 = Tensor(y6_np) + x7 = Tensor(x7_np) + y7 = Tensor(y7_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + div = DivNet() + output0 = div(x0, y0) + expect0 = np.divide(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = div(x1, y1) + expect1 = np.divide(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = div(x2, y2) + expect2 = np.divide(x2_np, y2_np).astype(np.float16) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + output3 = div(x3, y3) + expect3 = np.divide(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = div(x4, y4) + expect4 = np.divide(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + output5 = div(x5, y5) + expect5 = x5_np // y5_np + assert np.all(output5.asnumpy() == expect5) + + output6 = div(x6, y6) + expect6 = np.divide(x6_np, y6_np) + diff6 = output6.asnumpy() - expect6 + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output6.shape == expect6.shape + + output7 = div(x7, y7) + expect7 = np.divide(x7_np, y7_np).astype(np.int64) + assert np.all(output7.asnumpy() == expect7) + assert output7.shape == expect7.shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_mod(): + prop = 1 if np.random.random() < 0.5 else -1 + x0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + y0_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + x1_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + y1_np = np.random.randint(1, 100, (2, 1, 4, 4)).astype(np.float32) * prop + x2_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.float16) * prop + y2_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float16) * prop + x3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop + y3_np = np.random.randint(1, 100, 1).astype(np.float32) * prop + x4_np = np.array(768).astype(np.float32) * prop + y4_np = np.array(3072.5).astype(np.float32) * prop + x5_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int32) * prop + y5_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop + x6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop + y6_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop + x7_np = np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int64) * prop + y7_np = np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int64) * prop + + x0 = Tensor(x0_np) + y0 = Tensor(y0_np) + x1 = Tensor(x1_np) + y1 = Tensor(y1_np) + x2 = Tensor(x2_np) + y2 = Tensor(y2_np) + x3 = Tensor(x3_np) + y3 = Tensor(y3_np) + x4 = Tensor(x4_np) + y4 = Tensor(y4_np) + x5 = Tensor(x5_np) + y5 = Tensor(y5_np) + x6 = Tensor(x6_np) + y6 = Tensor(y6_np) + x7 = Tensor(x7_np) + y7 = Tensor(y7_np) + + context.set_context(mode=context.GRAPH_MODE, device_target='CPU') + mod = ModNet() + output0 = mod(x0, y0) + expect0 = np.mod(x0_np, y0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = mod(x1, y1) + expect1 = np.mod(x1_np, y1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = mod(x2, y2) + expect2 = np.mod(x2_np, y2_np).astype(np.float16) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape + + output3 = mod(x3, y3) + expect3 = np.mod(x3_np, y3_np) + diff3 = output3.asnumpy() - expect3 + error3 = np.ones(shape=expect3.shape) * 1.0e-5 + assert np.all(diff3 < error3) + assert output3.shape == expect3.shape + + output4 = mod(x4, y4) + expect4 = np.mod(x4_np, y4_np) + diff4 = output4.asnumpy() - expect4 + error4 = np.ones(shape=expect4.shape) * 1.0e-5 + assert np.all(diff4 < error4) + assert output4.shape == expect4.shape + + output5 = mod(x5, y5) + expect5 = np.mod(x5_np, y5_np) + assert np.all(output5.asnumpy() == expect5) + assert output5.shape == expect5.shape + + output6 = mod(x6, y6) + expect6 = np.mod(x6_np, y6_np) + diff6 = output6.asnumpy() - expect6 + error6 = np.ones(shape=expect6.shape) * 1.0e-5 + assert np.all(diff6 < error6) + assert output6.shape == expect6.shape + + output7 = mod(x7, y7) + expect7 = np.mod(x7_np, y7_np).astype(np.int64) + assert np.all(output7.asnumpy() == expect7) + assert output6.shape == expect6.shape + test_sub() +test_div() +test_mod() diff --git a/tests/st/ops/cpu/test_arithmetic_self_op.py b/tests/st/ops/cpu/test_arithmetic_self_op.py index d0e0da02c7..508f707859 100644 --- a/tests/st/ops/cpu/test_arithmetic_self_op.py +++ b/tests/st/ops/cpu/test_arithmetic_self_op.py @@ -32,6 +32,15 @@ class SquareNet(nn.Cell): return self.square(x) +class FloorNet(nn.Cell): + def __init__(self): + super(FloorNet, self).__init__() + self.floor = P.Floor() + + def construct(self, x): + return self.floor(x) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -78,4 +87,26 @@ def test_square(): print(output) assert np.all(output.asnumpy() == expect_output) + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_floor(): + net = FloorNet() + + x = np.random.randn(3, 4).astype(np.float16) + x = x * 100 + output = net(Tensor(x)) + expect_output = np.floor(x).astype(np.float16) + print(output.asnumpy()) + assert np.all(output.asnumpy() == expect_output) + + x = np.random.randn(4, 3).astype(np.float32) + x = x * 100 + output = net(Tensor(x)) + expect_output = np.floor(x) + print(output.asnumpy()) + assert np.all(output.asnumpy() == expect_output) + test_square() +test_floor() diff --git a/tests/st/ops/cpu/test_minimum_op.py b/tests/st/ops/cpu/test_minimum_op.py new file mode 100644 index 0000000000..0009c5c07a --- /dev/null +++ b/tests/st/ops/cpu/test_minimum_op.py @@ -0,0 +1,185 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore.nn import Cell +from mindspore.ops import operations as P + + +class ConstScalarAndTensorMinimum(Cell): + def __init__(self): + super(ConstScalarAndTensorMinimum, self).__init__() + self.min = P.Minimum() + self.x = 20 + + def construct(self, y): + return self.min(self.x, y) + + +class TwoTensorsMinimum(Cell): + def __init__(self): + super(TwoTensorsMinimum, self).__init__() + self.min = P.Minimum() + + def construct(self, x, y): + return self.min(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_constScalar_tensor_int(): + x = Tensor(np.array([[2, 3, 4], [100, 200, 300]]).astype(np.int32)) + expect = [[2, 3, 4], [20, 20, 20]] + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = ConstScalarAndTensorMinimum() + output = min_op(x) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_Not_Broadcast_int(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5).astype(np.int32) * prop + y = np.random.randn(3, 4, 5).astype(np.int32) * prop + expect = np.minimum(x, y).astype(np.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_Broadcast_int(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5).astype(np.int32) * prop + y = np.random.randn(3, 1, 1).astype(np.int32) * prop + expect = np.minimum(x, y).astype(np.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_Broadcast_oneDimension_int(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3).astype(np.int32) * prop + y = np.random.randn(3).astype(np.int32) * prop + expect = np.minimum(x, y).astype(np.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_notBroadcast_all_oneDimension_int(): + x = Tensor(np.array([[2]]).astype(np.int32)) + y = Tensor(np.array([[100]]).astype(np.int32)) + expect = [[2]] + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(x, y) + assert np.all(output.asnumpy() == expect) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_notBroadcast_float32(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5).astype(np.float32) * prop + y = np.random.randn(3, 4, 5).astype(np.float32) * prop + expect = np.minimum(x, y).astype(np.float32) + error = np.ones(shape=expect.shape) * 1.0e-5 + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + diff = output.asnumpy() - expect + assert np.all(np.abs(diff) < error) + assert output.shape == expect.shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_notBroadcast_float16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5).astype(np.float16) * prop + y = np.random.randn(3, 4, 5).astype(np.float16) * prop + expect = np.minimum(x, y).astype(np.float16) + error = np.ones(shape=expect.shape) * 1.0e-5 + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + diff = output.asnumpy() - expect + assert np.all(np.abs(diff) < error) + assert output.shape == expect.shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_Broadcast_float16(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 5).astype(np.float16) * prop + y = np.random.randn(3, 4, 1).astype(np.float16) * prop + expect = np.minimum(x, y).astype(np.float16) + error = np.ones(shape=expect.shape) * 1.0e-5 + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + diff = output.asnumpy() - expect + assert np.all(np.abs(diff) < error) + assert output.shape == expect.shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_minimum_two_tensors_notBroadcast_float64(): + prop = 100 if np.random.random() > 0.5 else -100 + x = np.random.randn(3, 4, 1).astype(np.float64) * prop + y = np.random.randn(3, 4, 5).astype(np.float64) * prop + expect = np.minimum(x, y).astype(np.float64) + error = np.ones(shape=expect.shape) * 1.0e-5 + + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + min_op = TwoTensorsMinimum() + output = min_op(Tensor(x), Tensor(y)) + diff = output.asnumpy() - expect + assert np.all(np.abs(diff) < error) + assert output.shape == expect.shape