|
|
|
@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class OperatorBase;
|
|
|
|
|
class InferShapeContext;
|
|
|
|
|
class ExecutionContext;
|
|
|
|
|
|
|
|
|
|
extern const Tensor* GetTensorFromVar(const Variable* var);
|
|
|
|
@ -169,10 +168,11 @@ class NOP : public OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class InferShapeContext {
|
|
|
|
|
class ExecutionContext {
|
|
|
|
|
public:
|
|
|
|
|
InferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
|
ExecutionContext(const OperatorBase& op, const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& device_context)
|
|
|
|
|
: op_(op), scope_(scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
const OperatorBase& op() const { return op_; }
|
|
|
|
|
|
|
|
|
@ -278,31 +278,6 @@ class InferShapeContext {
|
|
|
|
|
out_tensor->set_lod(in_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* InferShapeContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class ExecutionContext : public InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
ExecutionContext(const OperatorBase& op, const Scope& scope,
|
|
|
|
|
const platform::DeviceContext& device_context)
|
|
|
|
|
: InferShapeContext(op, scope), device_context_(device_context) {}
|
|
|
|
|
|
|
|
|
|
template <typename PlaceType,
|
|
|
|
|
typename DeviceType = typename platform::EigenDeviceConverter<
|
|
|
|
|
PlaceType>::EigenDeviceType>
|
|
|
|
@ -315,9 +290,25 @@ class ExecutionContext : public InferShapeContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
const OperatorBase& op_;
|
|
|
|
|
const Scope& scope_;
|
|
|
|
|
const platform::DeviceContext& device_context_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class CompileTimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
public:
|
|
|
|
|
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
|
|
|
|
|