fixing update arg table

pull/657/head
chuxing 4 years ago
parent bc21fbdb2e
commit 1ccdf2d27c

@ -112,8 +112,9 @@ Status OpTask::GetProfilingArgs(std::string &model_name, std::string &op_name, u
Status OpTask::UpdateRunInfo(const vector<GeTensorDesc> &input_desc, const vector<GeTensorDesc> &output_desc) { Status OpTask::UpdateRunInfo(const vector<GeTensorDesc> &input_desc, const vector<GeTensorDesc> &output_desc) {
return UNSUPPORTED; return UNSUPPORTED;
} }
Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param); Status OpTask::DoUpdateArgTable(const SingleOpModelParam &param, bool keep_workspace) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, keep_workspace);
auto all_addresses = BuildTaskUtils::JoinAddresses(addresses); auto all_addresses = BuildTaskUtils::JoinAddresses(addresses);
uintptr_t *arg_base = nullptr; uintptr_t *arg_base = nullptr;
size_t arg_num = 0; size_t arg_num = 0;
@ -132,6 +133,10 @@ Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
return SUCCESS; return SUCCESS;
} }
Status OpTask::UpdateArgTable(const SingleOpModelParam &param) {
return DoUpdateArgTable(param, true);
}
Status OpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc, Status OpTask::LaunchKernel(const vector<GeTensorDesc> &input_desc,
const vector<DataBuffer> &input_buffers, const vector<DataBuffer> &input_buffers,
vector<GeTensorDesc> &output_desc, vector<GeTensorDesc> &output_desc,
@ -792,10 +797,9 @@ Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc,
return SUCCESS; return SUCCESS;
} }
Status AiCpuTask::UpdateArgTable(const SingleOpModelParam &param) { Status AiCpuBaseTask::UpdateArgTable(const SingleOpModelParam &param) {
auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param, false); // aicpu do not have workspace, for now
io_addr_host_ = BuildTaskUtils::JoinAddresses(addresses); return DoUpdateArgTable(param, false);
return SUCCESS;
} }
void AiCpuTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) { void AiCpuTask::GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) {

@ -54,6 +54,8 @@ class OpTask {
rtStream_t stream); rtStream_t stream);
protected: protected:
Status DoUpdateArgTable(const SingleOpModelParam &param, bool keep_workspace);
DumpProperties dump_properties_; DumpProperties dump_properties_;
DumpOp dump_op_; DumpOp dump_op_;
OpDescPtr op_desc_; OpDescPtr op_desc_;
@ -110,7 +112,7 @@ class AiCpuBaseTask : public OpTask {
AiCpuBaseTask() = default; AiCpuBaseTask() = default;
~AiCpuBaseTask() override; ~AiCpuBaseTask() override;
UnknowShapeOpType GetUnknownType() const { return unknown_type_; } UnknowShapeOpType GetUnknownType() const { return unknown_type_; }
Status UpdateArgTable(const SingleOpModelParam &param) override;
protected: protected:
Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs);
Status SetInputConst(); Status SetInputConst();
@ -137,7 +139,6 @@ class AiCpuTask : public AiCpuBaseTask {
~AiCpuTask() override; ~AiCpuTask() override;
Status LaunchKernel(rtStream_t stream) override; Status LaunchKernel(rtStream_t stream) override;
Status UpdateArgTable(const SingleOpModelParam &param) override;
void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override; void GetIoAddr(uintptr_t *&arg_base, size_t &arg_count) override;
Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc, Status LaunchKernel(const std::vector<GeTensorDesc> &input_desc,

Loading…
Cancel
Save