You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/mindspore/ccsrc/frontend/parallel/dynamic_creator.h

187 lines
5.2 KiB

/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "frontend/parallel/ops_info/ops_info_head_files.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
// REGISTER(className) defines, at namespace scope:
//   * `objectCreator<className>`: a factory returning
//     std::make_shared<className>(name, in, out, attrs);
//   * `<className>Register`: a RegisterAction whose constructor records that factory in
//     DynCreator under the key "<className>" (the stringified class name).
// The factory's parameter list is spelled exactly like CreatFn, so the function converts to
// CreatFn implicitly and the compiler checks the match. A C-style cast to CreatFn would hide
// any signature mismatch, and calling a function through a pointer to a different function
// type is undefined behavior.
// The macro does not end in ';' -- write `REGISTER(Foo);` at the call site.
#define REGISTER(className)                                                                  \
  OperatorInfoPtr objectCreator##className(const std::string &name, const Shapes &in,       \
                                           const Shapes out, const PrimitiveAttrs &attrs) { \
    return std::make_shared<className>(name, in, out, attrs);                               \
  }                                                                                          \
  RegisterAction className##Register(#className, objectCreator##className)
typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out,
const PrimitiveAttrs &attrs);
class DynCreator {
public:
~DynCreator() = default;
// creat static singleton dyn_creator instance
static DynCreator &Instance() {
static DynCreator fac = DynCreator();
return fac;
}
// register
void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
// creator
OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
const PrimitiveAttrs &attrs, size_t count) {
std::string op_name = name + std::to_string(count);
auto iter = Function_map_.find(name);
if (iter == Function_map_.end()) {
MS_LOG(INFO) << name << " is not register yet";
return nullptr;
}
return iter->second(op_name, shape_in, shape_out, attrs);
}
private:
DynCreator() = default;
std::map<std::string, CreatFn> Function_map_;
};
class RegisterAction {
public:
RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
DynCreator::Instance().Regist(name, creatfn);
}
~RegisterAction() = default;
private:
std::string name_;
};
// Operator registration.
// Each REGISTER(X) below expands (via the REGISTER macro defined at the top of this
// namespace) to a creator function `objectCreatorX` and a namespace-scope object
// `RegisterAction XRegister`. That object's constructor adds the creator to
// DynCreator::Instance() under the key "X" during static initialization, and
// DynCreator::Creat(name, ...) later looks operators up by that class-name key.
// These are non-inline definitions in a header, so including this header from more than
// one translation unit would produce duplicate-symbol link errors. NOTE(review): presumably
// it is included by a single .cc file only -- confirm before adding new includers.
// To make a new operator-info class creatable by name, add a REGISTER line here. Its class
// must be declared by the time this list is reached (presumably via
// ops_info_head_files.h -- confirm).
REGISTER(MatMulInfo);
REGISTER(GeluInfo);
REGISTER(VirtualDatasetInfo);
REGISTER(BatchParallelInfo);
REGISTER(TanhInfo);
REGISTER(SoftmaxInfo);
REGISTER(LogSoftmaxInfo);
REGISTER(ActivationInfo);
REGISTER(SoftmaxCrossEntropyWithLogitsInfo);
REGISTER(SubInfo);
REGISTER(TensorAddInfo);
REGISTER(BiasAddInfo);
REGISTER(MulInfo);
REGISTER(DivInfo);
REGISTER(ModInfo);
REGISTER(RealDivInfo);
REGISTER(PowInfo);
REGISTER(ExpInfo);
REGISTER(OneHotInfo);
REGISTER(EqualInfo);
REGISTER(NotEqualInfo);
REGISTER(LogInfo);
REGISTER(CosInfo);
REGISTER(ACosInfo);
REGISTER(LogicalNotInfo);
REGISTER(L2NormalizeInfo);
REGISTER(LayerNormInfo);
REGISTER(ReduceMaxInfo);
REGISTER(ArgMaxWithValueInfo);
REGISTER(ArgMinWithValueInfo);
REGISTER(ReduceMeanInfo);
REGISTER(ReduceSumInfo);
REGISTER(ReduceMinInfo);
REGISTER(TransposeInfo);
REGISTER(PReLUInfo);
REGISTER(DropoutDoMaskInfo);
REGISTER(ReshapeInfo);
REGISTER(FloorDivInfo);
REGISTER(MaximumInfo);
REGISTER(MinimumInfo);
REGISTER(CastInfo);
REGISTER(GreaterInfo);
REGISTER(GreaterEqualInfo);
REGISTER(LessEqualInfo);
REGISTER(LessInfo);
REGISTER(ApproximateEqualInfo);
REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
REGISTER(AssignSubInfo);
REGISTER(FloorModInfo);
REGISTER(AssignInfo);
REGISTER(AssignAddInfo);
REGISTER(Atan2Info);
REGISTER(DivNoNanInfo);
REGISTER(LogicalAndInfo);
REGISTER(LogicalOrInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo);
REGISTER(ReLU6Info);
REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo);
REGISTER(SoftsignInfo);
REGISTER(GatherV2Info);
REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo);
REGISTER(SigmoidInfo);
REGISTER(GetNextInfo);
REGISTER(NegInfo);
REGISTER(AbsInfo);
REGISTER(AcoshInfo);
REGISTER(AsinInfo);
REGISTER(AsinhInfo);
REGISTER(AtanInfo);
REGISTER(AtanhInfo);
REGISTER(CeilInfo);
REGISTER(CoshInfo);
REGISTER(Expm1Info);
REGISTER(Log1pInfo);
REGISTER(SinInfo);
REGISTER(SinhInfo);
REGISTER(TanInfo);
REGISTER(RsqrtInfo);
REGISTER(InvInfo);
REGISTER(ReciprocalInfo);
REGISTER(RoundInfo);
REGISTER(FloorInfo);
REGISTER(SignInfo);
REGISTER(ErfInfo);
REGISTER(ErfcInfo);
REGISTER(ZerosLikeInfo);
REGISTER(OnesLikeInfo);
REGISTER(BesselI0eInfo);
REGISTER(BesselI1eInfo);
REGISTER(BatchMatMulInfo);
REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo);
REGISTER(SigmoidCrossEntropyWithLogitsInfo);
REGISTER(SquareInfo);
REGISTER(GatherV2PInfo);
REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);
REGISTER(StridedSliceInfo);
REGISTER(DropoutInfo);
REGISTER(ConcatInfo);
REGISTER(SplitInfo);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_DYNAMIC_CREATOR_H_