/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include "common/common_test.h" #include "frontend/operator/cc_implementations.h" namespace mindspore { namespace prim { class TestImplementations : public UT::Common { public: TestImplementations() {} virtual void SetUp() {} }; TEST_F(TestImplementations, ScalarAddTest) { ValuePtrList list; list.push_back(MakeValue(1)); list.push_back(MakeValue(2)); ASSERT_EQ(ScalarAdd(list)->cast()->value(), 3); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.5f)); ASSERT_EQ(ScalarAdd(list)->cast()->value(), 2.5f); list.clear(); list.push_back(MakeValue(3.0)); list.push_back(MakeValue(0.5)); ASSERT_EQ(ScalarAdd(list)->cast()->value(), 3.5); list.clear(); list.push_back(MakeValue(INT32_MAX)); list.push_back(MakeValue(2)); try { ScalarAdd(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(-1)); try { ScalarAdd(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the sum of two signed number") != std::string::npos); } list.clear(); } TEST_F(TestImplementations, ScalarSubTest) { ValuePtrList list; list.push_back(MakeValue(1)); list.push_back(MakeValue(3)); ASSERT_EQ(ScalarSub(list)->cast()->value(), -2); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.5f)); ASSERT_EQ(ScalarSub(list)->cast()->value(), -0.5f); list.clear(); list.push_back(MakeValue(3.0)); list.push_back(MakeValue(0.5)); ASSERT_EQ(ScalarSub(list)->cast()->value(), 2.5); list.clear(); list.push_back(MakeValue(INT32_MAX)); list.push_back(MakeValue(-1)); try { ScalarSub(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(1)); try { ScalarSub(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the sub of two signed number") != std::string::npos); } list.clear(); } TEST_F(TestImplementations, ScalarMulTest) { ValuePtrList list; list.push_back(MakeValue(2)); list.push_back(MakeValue(3)); ASSERT_EQ(ScalarMul(list)->cast()->value(), 6); list.clear(); list.push_back(MakeValue(2.0f)); list.push_back(MakeValue(1.5f)); ASSERT_EQ(ScalarMul(list)->cast()->value(), 3.0f); list.clear(); list.push_back(MakeValue(-2.0)); list.push_back(MakeValue(-4.0)); ASSERT_EQ(ScalarMul(list)->cast()->value(), 8.0); list.clear(); list.push_back(MakeValue(10)); list.push_back(MakeValue(INT32_MAX)); try { ScalarMul(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(-1)); try { ScalarMul(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(-2)); list.push_back(MakeValue(INT32_MAX)); try { ScalarMul(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(2)); list.push_back(MakeValue(INT32_MIN)); try { ScalarMul(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the mul of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(0)); list.push_back(MakeValue(INT32_MIN)); ASSERT_EQ(ScalarDiv(list)->cast()->value(), 0); list.clear(); } TEST_F(TestImplementations, ScalarDivTest) { ValuePtrList list; list.push_back(MakeValue(6)); list.push_back(MakeValue(3)); ASSERT_EQ(ScalarDiv(list)->cast()->value(), 2); list.clear(); list.push_back(MakeValue(3.0f)); list.push_back(MakeValue(1.5f)); ASSERT_EQ(ScalarDiv(list)->cast()->value(), 2.0f); list.clear(); list.push_back(MakeValue(-4.0)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarDiv(list)->cast()->value(), -2.0); list.clear(); list.push_back(MakeValue(INT32_MAX)); list.push_back(MakeValue(0)); try { ScalarDiv(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Divisor could not be zero") != std::string::npos); } list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(-1)); try { ScalarDiv(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the div of two signed number") != std::string::npos); } list.clear(); list.push_back(MakeValue(-1)); list.push_back(MakeValue(INT32_MIN)); ASSERT_EQ(ScalarDiv(list)->cast()->value(), 0); list.clear(); } TEST_F(TestImplementations, ScalarModTest) { ValuePtrList list; list.push_back(MakeValue(7)); list.push_back(MakeValue(3)); ASSERT_EQ(ScalarMod(list)->cast()->value(), 1); list.clear(); list.push_back(MakeValue(-8)); list.push_back(MakeValue(3)); ASSERT_EQ(ScalarMod(list)->cast()->value(), -2); list.clear(); list.push_back(MakeValue(-9)); list.push_back(MakeValue(2)); ASSERT_EQ(ScalarMod(list)->cast()->value(), -1); list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(0)); try { ScalarMod(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Could not mod to zero") != std::string::npos); } list.clear(); list.push_back(MakeValue(INT32_MIN)); list.push_back(MakeValue(-1)); try { ScalarMod(list); FAIL(); } catch (std::runtime_error const &err) { ASSERT_TRUE(std::string(err.what()).find("Overflow of the mod of two signed number") != std::string::npos); } list.clear(); } TEST_F(TestImplementations, ScalarUAddTest) { ValuePtrList list; list.push_back(MakeValue((uint32_t)1)); ASSERT_EQ(ScalarUAdd(list)->cast()->value(), 1); list.clear(); } TEST_F(TestImplementations, ScalarLogTest) { ValuePtrList list; list.push_back(MakeValue(static_cast(7.3890560989306495))); ASSERT_EQ(ScalarLog(list)->cast()->value(), 2.0); list.clear(); } TEST_F(TestImplementations, ScalarUSubTest) { ValuePtrList list; list.push_back(MakeValue(1)); ASSERT_EQ(ScalarUSub(list)->cast()->value(), -1); list.clear(); } TEST_F(TestImplementations, ScalarEqTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0f)); ASSERT_EQ(ScalarEq(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarEq(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0)); ASSERT_EQ(ScalarEq(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0)); list.push_back(MakeValue(1.0)); ASSERT_EQ(ScalarEq(list)->cast()->value(), true); list.clear(); } TEST_F(TestImplementations, ScalarLtTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0f)); ASSERT_EQ(ScalarLt(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarLt(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(2.5)); ASSERT_EQ(ScalarLt(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(2.5)); list.push_back(MakeValue(3.0)); ASSERT_EQ(ScalarLt(list)->cast()->value(), true); list.clear(); } TEST_F(TestImplementations, ScalarGtTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(2.0f)); ASSERT_EQ(ScalarGt(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(2.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarGt(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(2.0f)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarGt(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(2.5)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarGt(list)->cast()->value(), true); list.clear(); } TEST_F(TestImplementations, ScalarNeTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0f)); ASSERT_EQ(ScalarNe(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarNe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarNe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(2.0)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarNe(list)->cast()->value(), false); list.clear(); } TEST_F(TestImplementations, ScalarLeTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0f)); ASSERT_EQ(ScalarLe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarLe(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarLe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(6.0)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarLe(list)->cast()->value(), false); list.clear(); } TEST_F(TestImplementations, ScalarGeTest) { ValuePtrList list; list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(1.0f)); ASSERT_EQ(ScalarGe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarGe(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(1.0f)); list.push_back(MakeValue(2.0)); ASSERT_EQ(ScalarGe(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(6.0)); list.push_back(MakeValue(-1.0f)); ASSERT_EQ(ScalarGe(list)->cast()->value(), true); list.clear(); } TEST_F(TestImplementations, BoolNotTest) { ValuePtrList list; list.push_back(MakeValue(true)); ASSERT_EQ(BoolNot(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(false)); ASSERT_EQ(BoolNot(list)->cast()->value(), true); list.clear(); } TEST_F(TestImplementations, BoolAndTest) { ValuePtrList list; list.push_back(MakeValue(true)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolAnd(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(true)); list.push_back(MakeValue(true)); ASSERT_EQ(BoolAnd(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(false)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolAnd(list)->cast()->value(), false); list.clear(); } TEST_F(TestImplementations, BoolOrTest) { ValuePtrList list; list.push_back(MakeValue(true)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolOr(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(true)); list.push_back(MakeValue(true)); ASSERT_EQ(BoolOr(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(false)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolOr(list)->cast()->value(), false); list.clear(); } TEST_F(TestImplementations, BoolEqTest) { ValuePtrList list; list.push_back(MakeValue(true)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolEq(list)->cast()->value(), false); list.clear(); list.push_back(MakeValue(true)); list.push_back(MakeValue(true)); ASSERT_EQ(BoolEq(list)->cast()->value(), true); list.clear(); list.push_back(MakeValue(false)); list.push_back(MakeValue(false)); ASSERT_EQ(BoolEq(list)->cast()->value(), true); list.clear(); } } // namespace prim } // namespace mindspore