|
|
|
@ -309,7 +309,7 @@ template <>
|
|
|
|
|
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
|
|
|
|
|
const std::string& name) const;
|
|
|
|
|
|
|
|
|
|
class CompileTimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
class CompileTimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
|
|
|
|
|
: op_(op), block_(block) {}
|
|
|
|
@ -405,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
const BlockDescBind& block_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class RuntimeInferShapeContext : public InferShapeContextBase {
|
|
|
|
|
class RuntimeInferShapeContext : public InferShapeContext {
|
|
|
|
|
public:
|
|
|
|
|
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
|
|
|
|
|
: op_(op), scope_(scope) {}
|
|
|
|
@ -603,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
|
|
|
|
|
virtual void InferShape(InferShapeContext* ctx) const = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
// indicate kernel DataType by input data. Defaultly all input data must be
|
|
|
|
|