!13269 add actor runtime interface

From: @limingqi107
Reviewed-by: 
Signed-off-by:
pull/13269/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 114c894be2

@ -68,6 +68,11 @@ class DeviceAddress : public mindspore::DeviceSync {
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
void *GetMutablePtr() const override { return ptr_; }
void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
void IncreaseRefCount() { ref_count_++; }
void DecreaseRefCountUsed() { ref_count_dynamic_used_--; }
void ResetRefCountUsed() { ref_count_dynamic_used_ = ref_count_; }
size_t ref_count_dynamic_used() const { return ref_count_dynamic_used_; }
virtual bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt,
const ShapeVector &host_shape, TypeId host_type) const {
return true;
@ -85,7 +90,9 @@ class DeviceAddress : public mindspore::DeviceSync {
void set_ptr(void *ptr) { ptr_ = ptr; }
void *ptr_{nullptr};
size_t size_{0};
size_t ref_count_{0};
size_t ref_count_{1};
// It will be decreased in the running, and reset by ref_count_ when it is zero.
size_t ref_count_dynamic_used_{1};
string format_{"DefaultFormat"};
TypeId type_id_{kNumberTypeFloat16};
bool from_mem_pool_{false};

@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include <queue>
#include "mindrt/include/actor/op_actor.h"
#include "mindrt/include/async/future.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/framework/host_tensor_queue.h"
#include "base/base.h"
namespace mindspore {
namespace runtime {
// Abstract base class of the data source actors. A data source actor fetches data, processes it into
// device tensors, and then sends them to the kernel actors.
class DataSourceActor : public ActorBase {
public:
// name is forwarded to ActorBase; buffer_capacity is stored in buffer_capacity_.
DataSourceActor(std::string name, size_t buffer_capacity) : ActorBase(name), buffer_capacity_(buffer_capacity) {}
virtual ~DataSourceActor() = default;
// The entry point of data processing. Pure virtual: every concrete data source actor must implement it.
virtual void FetchData(OpContext<DeviceTensor> *context) = 0;
protected:
// The op arrows used to trigger the running of kernel actors.
std::vector<OpArrowPtr> output_op_arrows_;
// The buffered data; each queue element is one vector of device tensors.
std::queue<std::vector<DeviceTensorPtr>> buffers_;
// Set from the constructor argument. Presumably the maximum number of elements kept in buffers_ --
// TODO confirm against the FetchData implementations, which are not visible in this header.
size_t buffer_capacity_;
// The sequential numbers of the corresponding batches of data.
std::queue<uuids::uuid *> sequential_nums_;
};
// The data source actor whose data source is the device queue.
class DeviceQueueDataSourceActor : public DataSourceActor {
public:
// Both arguments are forwarded unchanged to the DataSourceActor constructor.
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity) : DataSourceActor(name, buffer_capacity) {}
virtual ~DeviceQueueDataSourceActor() = default;
// Implements the data processing entry declared by DataSourceActor; defined out of line.
void FetchData(OpContext<DeviceTensor> *context) override;
private:
// GraphScheduler is a friend, so it can access the private members of this actor.
friend class GraphScheduler;
// The input data kernel (for example GetNext) that fetches data from the device queue.
// NOTE(review): no constructor argument initializes it; presumably GraphScheduler assigns it -- confirm.
CNodePtr data_kernel_;
};
// The data source actor whose data source is the host queue.
class HostQueueDataSourceActor : public DataSourceActor {
public:
// name and buffer_capacity are forwarded to DataSourceActor; host_queue is stored in host_queue_.
HostQueueDataSourceActor(std::string name, size_t buffer_capacity, HostTensorQueuePtr host_queue)
: DataSourceActor(name, buffer_capacity), host_queue_(host_queue) {}
virtual ~HostQueueDataSourceActor() = default;
// Implements the data processing entry declared by DataSourceActor; defined out of line.
void FetchData(OpContext<DeviceTensor> *context) override;
private:
// GraphScheduler is a friend, so it can access the private members of this actor.
friend class GraphScheduler;
// The host tensor queue handed over at construction.
HostTensorQueuePtr host_queue_;
// The input data nodes that fetch data from the host queue.
// NOTE(review): no constructor argument fills it; presumably GraphScheduler does -- confirm.
std::vector<AnfNodePtr> data_nodes_;
};
// Shared-pointer aliases of the data source actor base class and its two concrete actors.
using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;
using DeviceQueueDSActorPtr = std::shared_ptr<DeviceQueueDataSourceActor>;
using HostQueueDSActorPtr = std::shared_ptr<HostQueueDataSourceActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_

@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/device_tensor_store.h"
#include "backend/kernel_compiler/kernel.h"
#include "ir/anf.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::kernel::AddressPtr;
// The kernel actor is used to receive the device tensors and control info to launch the kernel.
class KernelActor : public OpActor<DeviceTensor> {
public:
// name is forwarded to OpActor; kernel and device_context are stored as given.
// Both dependency counters (input data and input controls) start from zero.
KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context)
: OpActor(name), kernel_(kernel), device_context_(device_context), input_datas_num_(0), input_controls_num_(0) {}
virtual ~KernelActor() = default;
// The kernel actor runs when it receives the input data.
void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) override;
// The kernel actor runs when it receives the input control.
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
private:
// GraphScheduler is a friend, so it can access the private members of this actor.
friend class GraphScheduler;
// Check whether the condition for launch is satisfied for the given sequential number.
bool CheckLaunchCondition(const uuids::uuid *sequential_num);
// Fetch the args of kernel launch; the three vectors are output parameters.
void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs,
std::vector<AddressPtr> *kernel_workspaces);
// The real kernel launch processing.
void Launch(OpContext<DeviceTensor> *context);
// Send output data and output controls when the kernel launch is finished.
void SendOutput(OpContext<DeviceTensor> *context);
// Memory allocation and release steps; presumably for the device tensors used by the launch --
// TODO confirm against the definitions, which are not visible in this header.
void AllocateMemory(OpContext<DeviceTensor> *context);
void FreeMemory(OpContext<DeviceTensor> *context);
// Fetch the device tensors for launch.
void FetchInputDeviceTensor(const uuids::uuid *sequential_num);
void FetchOutputDeviceTensor();
void FetchWorkspaceDeviceTensor();
// The kernel node handed over at construction.
CNodePtr kernel_;
// The device interface of kernel launch.
const DeviceContext *device_context_;
// The number of dependent input data.
size_t input_datas_num_;
// The number of dependent input controls.
size_t input_controls_num_;
// Pair<index, anfNode> points to the dependent device tensor store; anfNode is the key of the device tensor store.
std::vector<std::pair<size_t, void *>> device_tensor_store_keys_;
// The device tensors for launch.
std::vector<DeviceTensorPtr> input_device_tensors_;
std::vector<DeviceTensorPtr> output_device_tensors_;
std::vector<DeviceTensorPtr> workspace_device_tensors_;
};
// Shared-pointer alias of the kernel actor.
using KernelActorPtr = std::shared_ptr<KernelActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
#include <vector>
#include <string>
#include <memory>
#include <unordered_map>
#include "mindrt/include/actor/op_actor.h"
#include "runtime/framework/device_tensor_store.h"
namespace mindspore {
namespace runtime {
// The loop count actor receives the controls of the tail kernel actors, which represent the end of one step,
// and decides by the loop count whether to loop the execution.
class LoopCountActor : public OpActor<DeviceTensor> {
 public:
  // name is forwarded to OpActor and loop_count is the constant number of loops.
  // Every counter, including input_controls_num_, starts from zero so that no member is left indeterminate.
  LoopCountActor(std::string name, size_t loop_count)
      : OpActor(name), loop_count_(loop_count), current_count_(0), input_controls_num_(0) {}
  virtual ~LoopCountActor() = default;

  // The loop count actor runs when it receives the input control.
  void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;

 private:
  // GraphScheduler is a friend, so it can access the private members of this actor.
  friend class GraphScheduler;

  // The loop count is constant, the current count is increased after each step finishes running.
  size_t loop_count_;
  size_t current_count_;
  // The number of dependent input controls. Zero until it is set; presumably the friend GraphScheduler
  // assigns the real number -- TODO confirm against the scheduler implementation.
  size_t input_controls_num_;
  // The output controls contain the data source actors and the no input kernel actors.
  std::vector<AID> data_source_aids_;
  std::vector<AID> no_input_kernel_aids_;
};
// Shared-pointer alias of the loop count actor.
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include "mindrt/include/actor/actor.h"
#include "runtime/framework/device_tensor_store.h"
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
// MemoryManagerActor needs to respond to memory alloc and free requests quickly, so it must be bound to a
// single thread.
class MemoryManagerActor : public ActorBase {
public:
// The actor is always registered under the fixed name "MemoryManagerActor".
MemoryManagerActor() : ActorBase("MemoryManagerActor") {}
virtual ~MemoryManagerActor() = default;
// Returns a mutable reference to the function-local static shared pointer.
// NOTE(review): the pointer is default-constructed (null) here and nothing in this header creates the actor;
// presumably a caller assigns it through the returned reference -- confirm before dereferencing.
static std::shared_ptr<MemoryManagerActor> &GetInstance() {
static std::shared_ptr<MemoryManagerActor> instance;
return instance;
}
// The process entry of memory alloc. alloc_list is taken by value.
// NOTE(review): the meaning of the bool result is not visible in this header -- presumably success; confirm.
bool AllocateMemory(std::vector<DeviceTensorPtr> alloc_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context);
// The process entry of memory free. free_list is taken by value.
void FreeMemory(std::vector<DeviceTensorPtr> free_list, const DeviceContext *device_context,
OpContext<DeviceTensor> *op_context);
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
#include <memory>
#include <unordered_map>
#include "runtime/device/device_address.h"
namespace mindspore {
namespace runtime {
using DeviceTensor = mindspore::device::DeviceAddress;
using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;

// A device tensor mainly consists of an address ptr, a size and a reference count. It is the basic data
// structure of kernel launch and is what gets transferred between actors.
// Some device tensors (such as weights and value nodes of the graph) have fixed addresses and are persistent,
// so they are kept in this store and looked up by the actors when they are needed.
class DeviceTensorStore {
 public:
  DeviceTensorStore() = default;
  virtual ~DeviceTensorStore() = default;

  // Access to the single function-local static store.
  static DeviceTensorStore &GetInstance() {
    static DeviceTensorStore instance;
    return instance;
  }

  // Inserts the value under the key, or overwrites the value already stored there. The array subscript is
  // used on purpose, because modifying an existing value must be supported.
  void Insert(void *key, DeviceTensorPtr value) { device_tensors_[key] = value; }

  // Removes the entry stored under the key; nothing happens when the key is unknown.
  void Remove(void *key) { (void)device_tensors_.erase(key); }

  // Returns the device tensor stored under the key, or nullptr when the key is unknown.
  DeviceTensorPtr Fetch(void *key) const {
    const auto found = device_tensors_.find(key);
    if (found == device_tensors_.end()) {
      return nullptr;
    }
    return found->second;
  }

 private:
  // The data storage of device tensors, the key is the anfNode ptr.
  std::unordered_map<void *, DeviceTensorPtr> device_tensors_;
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include "runtime/framework/actor/data_source_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
// The strategy that decides what triggers the running of the actors.
enum class GraphExecutionStrategy {
// The running of an actor is triggered only by data.
kPipeline,
// The running of an actor additionally needs to be triggered by control.
kStep
};
// The actor set generated by the graph transformer is the execution unit of the actor runtime.
// It includes data source actors, kernel actors and the loop count actor.
// The data source actor is used to obtain data and process them into device tensors,
// and then send them to the kernel actor. The kernel actor is used to receive the device tensors to launch
// the kernel. Take special notice of the no input kernel actor: it has no input device tensor, so it needs to be
// triggered externally. The loop count actor is used to receive the control of the tail kernel actor to represent
// the end of one step and decide whether to loop the execution by the loop count.
struct ActorSet {
std::vector<DataSourceActorPtr> data_source_actors_;
std::vector<KernelActorPtr> kernel_actors_;
// No input kernel actors need to be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
// Null until a loop count actor is assigned.
LoopCountActorPtr loop_count_actor_{nullptr};
};
// Shared-pointer alias of the actor set.
using ActorSetPtr = std::shared_ptr<ActorSet>;
// The graph scheduler transforms a kernel graph into an actor set, schedules the actors and runs them.
// It is declared as a friend by the actor classes in this change, so it can access their private members.
class GraphScheduler {
public:
GraphScheduler() = default;
virtual ~GraphScheduler() = default;
// Returns the single function-local static scheduler.
static GraphScheduler &GetInstance() {
static GraphScheduler instance;
return instance;
}
// Transform the graph to an actor DAG, which contains build and link.
// input_tensors defaults to nullptr and strategy defaults to GraphExecutionStrategy::kPipeline.
ActorSetPtr Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
const std::vector<tensor::TensorPtr> *input_tensors = nullptr,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
// Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
// will be supported in the future.
void Schedule(const ActorSetPtr &actor_set);
// The processing entry of actors running.
// NOTE(review): the meaning of the bool result is not visible in this header -- presumably success; confirm.
bool Run(const ActorSetPtr &actor_set);
private:
// Transform the nodes of the graph to actors.
ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
// Link actors to a DAG through the edge connection of the graph and the graph execution strategy.
void Link(ActorSetPtr actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
// The processing of actors build, one function per actor kind.
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph);
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
// The processing of actors link, one function per actor kind.
void LinkDataSourceActor(std::vector<DataSourceActorPtr> actors, const KernelGraphPtr &graph);
void LinkKernelActor(std::vector<KernelActorPtr> actors, const KernelGraphPtr &graph,
GraphExecutionStrategy strategy);
void LinkLoopCountActor(LoopCountActorPtr actor, const KernelGraphPtr &graph);
// Persist the device tensors of some nodes of the graph (such as weights and value nodes).
void PersistDeviceTensor(const KernelGraphPtr &graph);
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
// The actor set of each kernel graph, keyed by the graph.
std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actor_;
// The host tensor queue of each kernel graph, keyed by the graph.
std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
};
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_
#include <memory>
#include <queue>
#include <utility>
#include <vector>
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
using mindspore::tensor::TensorPtr;
// The host tensor queue stores batches of host tensors in FIFO order; its data will be fetched by the host
// queue data source actor.
class HostTensorQueue {
 public:
  HostTensorQueue() = default;
  virtual ~HostTensorQueue() = default;

  // Appends one batch of tensors to the back of the queue. The by-value parameter is moved into the queue,
  // so the vector is not copied a second time.
  void PushData(std::vector<TensorPtr> tensors) { buffers_.push(std::move(tensors)); }

  // Removes the batch at the front of the queue and returns it.
  // Returns an empty vector when the queue holds no data.
  std::vector<TensorPtr> PullData() {
    if (buffers_.empty()) {
      return {};
    }
    // The front element is popped right after, so its content can be moved out instead of copied.
    auto tensors = std::move(buffers_.front());
    buffers_.pop();
    return tensors;
  }

 private:
  // Each queue element is one batch of host tensors.
  std::queue<std::vector<TensorPtr>> buffers_;
};
// Shared-pointer alias of the host tensor queue.
using HostTensorQueuePtr = std::shared_ptr<HostTensorQueue>;
} // namespace runtime
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_

@ -14,6 +14,9 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
#define MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H
#include <list>
#include <vector>
#include <memory>
@ -68,11 +71,21 @@ class OpActor : public ActorBase {
public:
explicit OpActor(std::string op_name) : ActorBase(op_name) {}
virtual ~OpActor() = default;
virtual void OpRun(OpDataPtr<T> inputs, OpContext<T> *context = nullptr) {}
// The op actor run when receive the input data.
virtual void RunOpData(OpDataPtr<T> input_data, OpContext<T> *context = nullptr) {}
// The op actor run when receive the input control.
virtual void RunOpControl(AID *input_control, OpContext<T> *context = nullptr) {}
protected:
// The op data.
std::unordered_map<uuids::uuid *, std::vector<OpDataPtr<T>>> input_op_datas_;
std::vector<OpArrowPtr> output_op_arrow_;
std::vector<OpArrowPtr> output_op_arrows_;
// The op controls.
std::unordered_map<uuids::uuid *, std::vector<AID *>> input_op_controls_;
std::vector<AID> output_op_controls_;
};
template <typename T>
@ -84,7 +97,7 @@ Future<std::list<int>> MindrtAsyncRun(const std::vector<OpDataPtr<T>> &inputData
Future<std::list<int>> collect = mindspore::Collect<int>(futures);
for (auto data : inputData) {
Async(data->op_id_, &mindspore::OpActor<T>::OpRun, data, context);
Async(data->op_id_, &mindspore::OpActor<T>::RunOpData, data, context);
}
return collect;
@ -112,3 +125,5 @@ int MindrtRun(const std::vector<OpDataPtr<T>> &inputData, std::vector<OpDataPtr<
}
} // namespace mindspore
#endif // MINDSPORE_CORE_MINDRT_INCLUDE_ACTOR_OP_ACTOR_H

@ -40,7 +40,7 @@ int LiteOpActor::CompileArrow() {
MS_LOG(ERROR) << "create OpArrow failed, out kernel: " << out->name();
return RET_ERROR;
}
output_op_arrow_.emplace_back(std::move(arrow));
output_op_arrows_.emplace_back(std::move(arrow));
break;
}
}

@ -36,7 +36,7 @@ class LiteOpActor : public OpActor<lite::Tensor> {
public:
explicit LiteOpActor(kernel::LiteKernel *kernel) : OpActor<lite::Tensor>(kernel->name()), kernel_(kernel) {}
virtual ~LiteOpActor() = default;
virtual void OpRun(OpDataPtr<Tensor> inputs, OpContext<Tensor> *context = nullptr) {
virtual void RunOpData(OpDataPtr<Tensor> inputs, OpContext<Tensor> *context = nullptr) {
auto op_uuid = context->sequential_num_;
input_op_datas_[op_uuid].push_back(inputs);
if (input_op_datas_[op_uuid].size() < kernel_->in_tensors().size()) {

Loading…
Cancel
Save