|
|
|
@ -32,8 +32,8 @@ namespace parallel {
|
|
|
|
|
class ActivationBase : public OperatorInfo {
|
|
|
|
|
public:
|
|
|
|
|
ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
|
|
|
|
const PrimitiveAttrs& attrs)
|
|
|
|
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs) {}
|
|
|
|
|
const PrimitiveAttrs& attrs, OperatorCostPtr cost)
|
|
|
|
|
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
|
|
|
|
|
~ActivationBase() override = default;
|
|
|
|
|
|
|
|
|
|
Status Init(const StrategyPtr& strategy) override;
|
|
|
|
@ -51,19 +51,13 @@ class Activation : public ActivationBase {
|
|
|
|
|
public:
|
|
|
|
|
Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
|
|
|
|
const PrimitiveAttrs& attrs)
|
|
|
|
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs) {
|
|
|
|
|
ac_cost_ptr_ = std::make_shared<ActivationCost>();
|
|
|
|
|
}
|
|
|
|
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ActivationCost>()) {}
|
|
|
|
|
~Activation() override = default;
|
|
|
|
|
Status GenerateStrategies(int32_t stage_id) override;
|
|
|
|
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
|
|
|
|
OperatorCostPtr GetOperatorCost() const override { return ac_cost_ptr_; }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
ActivationCostPtr ac_cost_ptr_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ActivationInfo : public Activation {
|
|
|
|
@ -108,13 +102,10 @@ class Softmax : public ActivationBase {
|
|
|
|
|
public:
|
|
|
|
|
explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
|
|
|
|
|
const PrimitiveAttrs& attrs)
|
|
|
|
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs) {
|
|
|
|
|
sm_cost_ptr_ = std::make_shared<SoftmaxCost>();
|
|
|
|
|
}
|
|
|
|
|
: ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCost>()) {}
|
|
|
|
|
~Softmax() override = default;
|
|
|
|
|
Status GenerateStrategies(int32_t stage_id) override;
|
|
|
|
|
Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
|
|
|
|
|
OperatorCostPtr GetOperatorCost() const override { return sm_cost_ptr_; }
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
Status CheckStrategy(const StrategyPtr& strategy) override;
|
|
|
|
@ -122,7 +113,6 @@ class Softmax : public ActivationBase {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::vector<int32_t> axis_;
|
|
|
|
|
SoftmaxCostPtr sm_cost_ptr_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SoftmaxInfo : public Softmax {
|
|
|
|
|