|
|
@ -109,9 +109,9 @@ class LiteKernel {
|
|
|
|
|
|
|
|
|
|
|
|
virtual bool IsEval() const { return !this->train_mode_; }
|
|
|
|
virtual bool IsEval() const { return !this->train_mode_; }
|
|
|
|
|
|
|
|
|
|
|
|
virtual void SetTrainable(bool trainable = true) { this->trainable_ = trainable; }
|
|
|
|
virtual void set_trainable(bool trainable = true) { this->trainable_ = trainable; }
|
|
|
|
|
|
|
|
|
|
|
|
virtual bool IsTrainable() const { return this->trainable_; }
|
|
|
|
virtual bool is_trainable() const { return this->trainable_; }
|
|
|
|
|
|
|
|
|
|
|
|
void set_name(const std::string &name) { this->name_ = name; }
|
|
|
|
void set_name(const std::string &name) { this->name_ = name; }
|
|
|
|
|
|
|
|
|
|
|
@ -146,9 +146,9 @@ class LiteKernel {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void SetInKernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
|
|
|
|
void set_in_kernel(const std::vector<LiteKernel *> &kernel) { this->in_kernels_ = kernel; }
|
|
|
|
|
|
|
|
|
|
|
|
void SetOutKernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }
|
|
|
|
void set_out_kernel(const std::vector<LiteKernel *> &kernel) { this->out_kernels_ = kernel; }
|
|
|
|
|
|
|
|
|
|
|
|
const std::vector<LiteKernel *> &in_kernels() const { return this->in_kernels_; }
|
|
|
|
const std::vector<LiteKernel *> &in_kernels() const { return this->in_kernels_; }
|
|
|
|
|
|
|
|
|
|
|
@ -165,18 +165,18 @@ class LiteKernel {
|
|
|
|
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }
|
|
|
|
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }
|
|
|
|
|
|
|
|
|
|
|
|
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
|
|
|
|
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
|
|
|
|
void SetWorkspaceSize(size_t value) { workspace_size_ = value; }
|
|
|
|
void set_workspace_size(size_t value) { workspace_size_ = value; }
|
|
|
|
size_t GetWorkspaceSize() { return workspace_size_; }
|
|
|
|
size_t workspace_size() { return workspace_size_; }
|
|
|
|
static void AllocWorkspace(size_t size);
|
|
|
|
static void AllocWorkspace(size_t size);
|
|
|
|
static void FreeWorkspace();
|
|
|
|
static void FreeWorkspace();
|
|
|
|
void *GetWorkspace() { return workspace_; }
|
|
|
|
void *workspace() { return workspace_; }
|
|
|
|
|
|
|
|
|
|
|
|
SubGraphType subgraph_type() const { return this->subgraph_type_; }
|
|
|
|
SubGraphType subgraph_type() const { return this->subgraph_type_; }
|
|
|
|
|
|
|
|
|
|
|
|
virtual std::string ToString() const;
|
|
|
|
virtual std::string ToString() const;
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
protected:
|
|
|
|
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->GetInferFlag()); }
|
|
|
|
bool InferShapeDone() { return !(primitive_ != nullptr && !primitive_->infer_flag()); }
|
|
|
|
|
|
|
|
|
|
|
|
KernelKey desc_{};
|
|
|
|
KernelKey desc_{};
|
|
|
|
std::string name_;
|
|
|
|
std::string name_;
|
|
|
|