!2976 Add input infos to output for apply op to match the tbe registered info
Merge pull request !2976 from YuJianfeng/input2outputpull/2976/MERGE
commit
c8d7ac13aa
@ -0,0 +1,115 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
void GetInputOrOutputNames(const CNodePtr &cnode, const std::string &attr_name, std::vector<std::string> *names_vec) {
|
||||
MS_EXCEPTION_IF_NULL(names_vec);
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
ValuePtr names_value = primitive->GetAttr(attr_name);
|
||||
if (names_value == nullptr) {
|
||||
return;
|
||||
}
|
||||
*names_vec = GetValue<std::vector<std::string>>(names_value);
|
||||
}
|
||||
|
||||
void AddOutputs(const CNodePtr &cnode, const std::vector<size_t> &input_indices) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<std::string> input_names_vec;
|
||||
GetInputOrOutputNames(cnode, kAttrInputNames, &input_names_vec);
|
||||
std::vector<std::string> output_names_vec;
|
||||
GetInputOrOutputNames(cnode, kAttrOutputNames, &output_names_vec);
|
||||
AbstractBasePtrList abstract_list;
|
||||
auto origin_abstract = cnode->abstract();
|
||||
MS_EXCEPTION_IF_NULL(origin_abstract);
|
||||
if (origin_abstract->isa<abstract::AbstractTuple>()) {
|
||||
auto origin_abstract_tuple = dyn_cast<abstract::AbstractTuple>(origin_abstract);
|
||||
MS_EXCEPTION_IF_NULL(origin_abstract_tuple);
|
||||
AbstractBasePtrList origin_abstract_list = origin_abstract_tuple->elements();
|
||||
(void)std::copy(origin_abstract_list.begin(), origin_abstract_list.end(), std::back_inserter(abstract_list));
|
||||
} else {
|
||||
abstract_list.emplace_back(origin_abstract);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < input_indices.size(); ++i) {
|
||||
size_t index = input_indices[i];
|
||||
if (index + 1 >= cnode->inputs().size()) {
|
||||
MS_LOG(INFO) << "The input index " << index << " for converting to output is out of range, "
|
||||
<< "node: " << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
auto node_to_output = cnode->input(index + 1);
|
||||
MS_EXCEPTION_IF_NULL(node_to_output);
|
||||
abstract_list.emplace_back(node_to_output->abstract());
|
||||
if (!input_names_vec.empty() && !output_names_vec.empty() && index < input_names_vec.size()) {
|
||||
output_names_vec.emplace_back(input_names_vec[index]);
|
||||
}
|
||||
}
|
||||
if (!output_names_vec.empty()) {
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names_vec), cnode);
|
||||
}
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
cnode->set_abstract(abstract_tuple);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string op_name = AnfAlgo::GetCNodeName(cnode);
|
||||
InputToOutputRegister reg;
|
||||
if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) {
|
||||
return nullptr;
|
||||
}
|
||||
int output_num = op_finder_->GetOpRegisteredOutputNum(op_name);
|
||||
// No need add output when it is not a tbe op.
|
||||
if (output_num == -1) {
|
||||
return nullptr;
|
||||
}
|
||||
// No need add output if the output num matches the registered output num for tbe.
|
||||
if (AnfAlgo::GetOutputTensorNum(cnode) >= IntToSize(output_num)) {
|
||||
return nullptr;
|
||||
}
|
||||
bool is_origin_tuple_output = AnfAlgo::IsTupleOutput(cnode);
|
||||
AddOutputs(cnode, reg.input_indices());
|
||||
// No need to create tuple_getitem if the origin output is a tuple because there has already been some tuple_getitems
|
||||
// pointed to the outputs.
|
||||
if (is_origin_tuple_output) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> new_outputs;
|
||||
auto new_abstract_tuple = dyn_cast<abstract::AbstractTuple>(cnode->abstract());
|
||||
MS_EXCEPTION_IF_NULL(new_abstract_tuple);
|
||||
CreateMultipleOutputsOfAnfNode(func_graph, cnode, new_abstract_tuple->size(), &new_outputs);
|
||||
if (new_outputs.size() != new_abstract_tuple->size()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to create outputs of " << cnode->DebugString();
|
||||
}
|
||||
return new_outputs[0];
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,39 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/ascend/ascend_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AddInputToOutput : public PatternProcessPass {
|
||||
public:
|
||||
explicit AddInputToOutput(bool multigraph = true)
|
||||
: PatternProcessPass("add_input_to_output", multigraph), op_finder_(std::make_shared<OpFinder>()) {}
|
||||
~AddInputToOutput() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
OpFinderPtr op_finder_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADD_INPUT_TO_OUTPUT_H_
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fusion/input_to_output_registry.h"
|
||||
#include <utility>
|
||||
#include "utils/utils.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool ApplyRMSPropPreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
|
||||
bool FusedMulApplyMomentumPreCheck(const CNodePtr &node) {
|
||||
TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||
return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
|
||||
}
|
||||
|
||||
bool SparseApplyRMSPropPreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
|
||||
bool ApplyAdagradV2PreCheck(const CNodePtr &node) {
|
||||
TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||
return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
|
||||
}
|
||||
|
||||
bool ApplyKerasMomentumPreCheck(const CNodePtr &node) {
|
||||
TypeId data_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
|
||||
return !(data_type != kNumberTypeFloat32 && data_type != kNumberTypeFloat16);
|
||||
}
|
||||
|
||||
bool SparseApplyFtrlPreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
|
||||
bool SparseApplyFtrlV2PreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
|
||||
bool SparseApplyAdagradV2PreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
|
||||
bool SparseApplyAdadeltaPreCheck(const CNodePtr &node) {
|
||||
return !(AnfAlgo::GetPrevNodeOutputInferDataType(node, 0) != kNumberTypeFloat32);
|
||||
}
|
||||
} // namespace
|
||||
InputToOutputRegistry::InputToOutputRegistry() {
|
||||
Register(kApplyRMSPropOpName, {1, 2}, ApplyRMSPropPreCheck);
|
||||
Register(kFusedMulApplyMomentumOpName, {1}, FusedMulApplyMomentumPreCheck);
|
||||
Register(kApplyAdagradOpName, {1});
|
||||
Register(kApplyAdagradDAName, {1, 2});
|
||||
Register(kApplyAdadeltaOpName, {1, 2});
|
||||
Register(kApplyPowerSignOpName, {1});
|
||||
Register(kApplyProximalAdagradOpName, {1});
|
||||
Register(kApplyAdaMaxOpName, {1, 2});
|
||||
Register(kApplyAdagradV2OpName, {1}, ApplyAdagradV2PreCheck);
|
||||
Register(kApplyKerasMomentumOpName, {1}, ApplyKerasMomentumPreCheck);
|
||||
Register(kSparseApplyFtrlOpName, {1, 2}, SparseApplyFtrlPreCheck);
|
||||
Register(kSparseApplyFtrlV2OpName, {1, 2}, SparseApplyFtrlV2PreCheck);
|
||||
Register(kSparseApplyAdagradV2OpName, {1}, SparseApplyAdagradV2PreCheck);
|
||||
Register(kSparseApplyProximalAdagradOpName, {1});
|
||||
Register(kSparseApplyAdagradOpName, {1});
|
||||
Register(kApplyFtrlV2OpName, {1, 2});
|
||||
Register(kApplyMomentumOpName, {1});
|
||||
Register(kApplyFtrlOpName, {1, 2});
|
||||
Register(kApplyAdamOpName, {1, 2});
|
||||
Register(kApplyCenteredRMSPropOpName, {1, 2, 3});
|
||||
Register(kApplyAddSignOpName, {1});
|
||||
Register(kSparseApplyRMSPropOpName, {1, 2}, SparseApplyRMSPropPreCheck);
|
||||
Register(kSparseApplyAdadeltaOpName, {1, 2}, SparseApplyAdadeltaPreCheck);
|
||||
Register(kApplyAdamWithAmsgradOpName, {1, 2});
|
||||
}
|
||||
|
||||
InputToOutputRegistry &InputToOutputRegistry::Instance() {
|
||||
static InputToOutputRegistry instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void InputToOutputRegistry::Register(const InputToOutputRegister ®) {
|
||||
auto op_name = reg.op_name();
|
||||
if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) {
|
||||
(void)op_input_to_output_map_.insert(make_pair(op_name, reg));
|
||||
MS_LOG(DEBUG) << op_name << " input2output register successfully!";
|
||||
}
|
||||
}
|
||||
|
||||
void InputToOutputRegistry::Register(const std::string &op_name, const std::vector<size_t> &input_indices,
|
||||
const PreCheckFunc &pre_check_func) {
|
||||
if (op_input_to_output_map_.find(op_name) == op_input_to_output_map_.end()) {
|
||||
InputToOutputRegister reg(op_name, pre_check_func);
|
||||
reg.set_input_indices(input_indices);
|
||||
(void)op_input_to_output_map_.insert(make_pair(op_name, reg));
|
||||
MS_LOG(DEBUG) << op_name << " input2output register successfully!";
|
||||
}
|
||||
}
|
||||
|
||||
bool InputToOutputRegistry::GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const {
|
||||
if (op_input_to_output_map_.find(op_name) != op_input_to_output_map_.end()) {
|
||||
*reg = op_input_to_output_map_.at(op_name);
|
||||
MS_LOG(DEBUG) << op_name << " input2output find in registry.";
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,64 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
using PreCheckFunc = std::function<bool(const CNodePtr &node)>;
|
||||
class InputToOutputRegister {
|
||||
public:
|
||||
explicit InputToOutputRegister(
|
||||
const std::string &op_name = "", const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; })
|
||||
: op_name_(op_name), pre_check_func_(pre_check_func) {}
|
||||
virtual ~InputToOutputRegister() = default;
|
||||
|
||||
void set_input_indices(const std::vector<size_t> &input_indices) { input_indices_ = input_indices; }
|
||||
|
||||
const std::vector<size_t> &input_indices() const { return input_indices_; }
|
||||
const std::string &op_name() const { return op_name_; }
|
||||
|
||||
private:
|
||||
std::string op_name_;
|
||||
std::vector<size_t> input_indices_;
|
||||
PreCheckFunc pre_check_func_;
|
||||
};
|
||||
|
||||
class InputToOutputRegistry {
|
||||
public:
|
||||
static InputToOutputRegistry &Instance();
|
||||
void Register(const InputToOutputRegister ®);
|
||||
void Register(
|
||||
const std::string &op_name, const std::vector<size_t> &input_indices,
|
||||
const PreCheckFunc &pre_check_func = [](const CNodePtr &node) { return true; });
|
||||
bool GetRegisterByOpName(const std::string &op_name, InputToOutputRegister *reg) const;
|
||||
|
||||
private:
|
||||
InputToOutputRegistry();
|
||||
~InputToOutputRegistry() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(InputToOutputRegistry)
|
||||
std::unordered_map<std::string, InputToOutputRegister> op_input_to_output_map_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_IR_FUSION_INPUT_TO_OUTPUT_REGISTRY_H_
|
@ -0,0 +1,74 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "pre_activate/ascend/ir_fusion/add_input_to_output.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestHWAddInputToOutput : public BackendCommon {
|
||||
public:
|
||||
TestHWAddInputToOutput() : getPyFun_("gtest_input.pre_activate.add_input_to_output_test", true) {}
|
||||
~TestHWAddInputToOutput() override = default;
|
||||
|
||||
public:
|
||||
UT::PyFuncGraphFetcher getPyFun_;
|
||||
};
|
||||
|
||||
class MockOpFinder : public OpFinder {
|
||||
public:
|
||||
MockOpFinder() = default;
|
||||
~MockOpFinder() override = default;
|
||||
int GetOpRegisteredOutputNum(const std::string &op_name) override { return 2; }
|
||||
};
|
||||
|
||||
TEST_F(TestHWAddInputToOutput, test_add_input_to_output) {
|
||||
FuncGraphPtr g = getPyFun_.CallAndParseRet("test_add_input_to_output", "before");
|
||||
EXPECT_NE(g, nullptr);
|
||||
std::vector<int> shp{2, 32, 224, 224};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < 5; ++i) {
|
||||
args_spec_list.push_back(x_abstract);
|
||||
}
|
||||
auto kg = GetKernelGraph(g, args_spec_list);
|
||||
EXPECT_NE(kg, nullptr);
|
||||
auto ret = kg->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
auto make_tuple = ret->input(1);
|
||||
EXPECT_NE(make_tuple, nullptr);
|
||||
auto momentum = make_tuple->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(momentum, nullptr);
|
||||
EXPECT_NE(momentum->abstract(), nullptr);
|
||||
EXPECT_FALSE(momentum->abstract()->isa<abstract::AbstractTuple>());
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
auto pass = std::make_shared<opt::AddInputToOutput>();
|
||||
pass->op_finder_ = std::make_shared<MockOpFinder>();
|
||||
pm->AddPass(pass);
|
||||
optimizer->AddPassManager(pm);
|
||||
(void)optimizer->Optimize(kg);
|
||||
EXPECT_TRUE(momentum->abstract()->isa<abstract::AbstractTuple>());
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,39 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
ApplyMomentum = P.ApplyMomentum()
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
||||
def __call__(self, fn):
|
||||
self.fnDict[fn.__name__] = fn
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_add_input_to_output(tag):
|
||||
fns = FnDict()
|
||||
|
||||
@fns
|
||||
def before(input0, input1, input2, input3, input4):
|
||||
return ApplyMomentum(input0, input1, input2, input3, input4)
|
||||
|
||||
return fns[tag]
|
Loading…
Reference in new issue