commit
114c894be2
@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <queue>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "mindrt/include/async/future.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/framework/host_tensor_queue.h"
|
||||
#include "base/base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// The data source actor is used to fetch data and process them into device tensors,
|
||||
// and then send them to kernel actor.
|
||||
class DataSourceActor : public ActorBase {
|
||||
public:
|
||||
DataSourceActor(std::string name, size_t buffer_capacity) : ActorBase(name), buffer_capacity_(buffer_capacity) {}
|
||||
virtual ~DataSourceActor() = default;
|
||||
|
||||
// The process entry of data processing.
|
||||
virtual void FetchData(OpContext<DeviceTensor> *context) = 0;
|
||||
|
||||
protected:
|
||||
// To trigger kernel actors running by op arrows.
|
||||
std::vector<OpArrowPtr> output_op_arrows_;
|
||||
|
||||
// The buffers store the data.
|
||||
std::queue<std::vector<DeviceTensorPtr>> buffers_;
|
||||
size_t buffer_capacity_;
|
||||
|
||||
// The sequential number of corresponding batch data.
|
||||
std::queue<uuids::uuid *> sequential_nums_;
|
||||
};
|
||||
|
||||
// The class represents that the data source is device queue.
|
||||
class DeviceQueueDataSourceActor : public DataSourceActor {
|
||||
public:
|
||||
DeviceQueueDataSourceActor(std::string name, size_t buffer_capacity) : DataSourceActor(name, buffer_capacity) {}
|
||||
virtual ~DeviceQueueDataSourceActor() = default;
|
||||
|
||||
void FetchData(OpContext<DeviceTensor> *context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Input data kernel(for example GetNext) fetches data from device queue.
|
||||
CNodePtr data_kernel_;
|
||||
};
|
||||
|
||||
// The class represents that the data source is host queue.
|
||||
class HostQueueDataSourceActor : public DataSourceActor {
|
||||
public:
|
||||
HostQueueDataSourceActor(std::string name, size_t buffer_capacity, HostTensorQueuePtr host_queue)
|
||||
: DataSourceActor(name, buffer_capacity), host_queue_(host_queue) {}
|
||||
virtual ~HostQueueDataSourceActor() = default;
|
||||
|
||||
void FetchData(OpContext<DeviceTensor> *context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
HostTensorQueuePtr host_queue_;
|
||||
// Input data nodes fetch data from host queue.
|
||||
std::vector<AnfNodePtr> data_nodes_;
|
||||
};
|
||||
|
||||
using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;
|
||||
using DeviceQueueDSActorPtr = std::shared_ptr<DeviceQueueDataSourceActor>;
|
||||
using HostQueueDSActorPtr = std::shared_ptr<HostQueueDataSourceActor>;
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_DATA_SOURCE_ACTOR_H_
|
@ -0,0 +1,91 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
|
||||
// The kernel actor is used to receive the device tensors and control info to luanch kernel.
|
||||
class KernelActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
KernelActor(std::string name, CNodePtr kernel, const DeviceContext *device_context)
|
||||
: OpActor(name), kernel_(kernel), device_context_(device_context), input_datas_num_(0), input_controls_num_(0) {}
|
||||
virtual ~KernelActor() = default;
|
||||
|
||||
// The kernel actor run when receive the input data.
|
||||
void RunOpData(OpDataPtr<DeviceTensor> input_data, OpContext<DeviceTensor> *context) override;
|
||||
// The kernel actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// Check whether satisfy the condition for launch.
|
||||
bool CheckLaunchCondition(const uuids::uuid *sequential_num);
|
||||
// Fetch the args of kernel launch.
|
||||
void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs,
|
||||
std::vector<AddressPtr> *kernel_workspaces);
|
||||
// The real kernel launch processing.
|
||||
void Launch(OpContext<DeviceTensor> *context);
|
||||
// Send output data and output controls when finish kernel launch.
|
||||
void SendOutput(OpContext<DeviceTensor> *context);
|
||||
|
||||
void AllocateMemory(OpContext<DeviceTensor> *context);
|
||||
void FreeMemory(OpContext<DeviceTensor> *context);
|
||||
|
||||
// Fetch the device tensor for launch.
|
||||
void FetchInputDeviceTensor(const uuids::uuid *sequential_num);
|
||||
void FetchOutputDeviceTensor();
|
||||
void FetchWorkspaceDeviceTensor();
|
||||
|
||||
CNodePtr kernel_;
|
||||
// The device interface of kernel launch.
|
||||
const DeviceContext *device_context_;
|
||||
|
||||
// The dependent input data number.
|
||||
size_t input_datas_num_;
|
||||
// The dependent input controls number.
|
||||
size_t input_controls_num_;
|
||||
|
||||
// Pair<index, anfNode> points to the dependent device tensor store, anfNode is the key of the device tensor store.
|
||||
std::vector<std::pair<size_t, void *>> device_tensor_store_keys_;
|
||||
|
||||
// The device tensors for launch.
|
||||
std::vector<DeviceTensorPtr> input_device_tensors_;
|
||||
std::vector<DeviceTensorPtr> output_device_tensors_;
|
||||
std::vector<DeviceTensorPtr> workspace_device_tensors_;
|
||||
};
|
||||
|
||||
using KernelActorPtr = std::shared_ptr<KernelActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_KERNEL_ACTOR_H_
|
@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
|
||||
// and decide whether to loop execution by loop count.
|
||||
class LoopCountActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
LoopCountActor(std::string name, size_t loop_count) : OpActor(name), loop_count_(loop_count), current_count_(0) {}
|
||||
virtual ~LoopCountActor() = default;
|
||||
|
||||
// The loop count actor run when receive the input control.
|
||||
void RunOpControl(AID *input_control, OpContext<DeviceTensor> *context) override;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
|
||||
// The loop count is constant, the current count is increased after each step running finished.
|
||||
size_t loop_count_;
|
||||
size_t current_count_;
|
||||
|
||||
// The dependent input controls number.
|
||||
size_t input_controls_num_;
|
||||
|
||||
// The output controls contain the data source actors and the no input kernel actors.
|
||||
std::vector<AID> data_source_aids_;
|
||||
std::vector<AID> no_input_kernel_aids_;
|
||||
};
|
||||
|
||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_LOOP_COUNT_ACTOR_H_
|
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "mindrt/include/actor/actor.h"
|
||||
#include "runtime/framework/device_tensor_store.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
|
||||
// MemoryManagerActor need response to memory alloc and free quickly, so must bind single thread.
|
||||
class MemoryManagerActor : public ActorBase {
|
||||
public:
|
||||
MemoryManagerActor() : ActorBase("MemoryManagerActor") {}
|
||||
virtual ~MemoryManagerActor() = default;
|
||||
|
||||
static std::shared_ptr<MemoryManagerActor> &GetInstance() {
|
||||
static std::shared_ptr<MemoryManagerActor> instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// The process entry of memory alloc.
|
||||
bool AllocateMemory(std::vector<DeviceTensorPtr> alloc_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context);
|
||||
// The process entry of memory free.
|
||||
void FreeMemory(std::vector<DeviceTensorPtr> free_list, const DeviceContext *device_context,
|
||||
OpContext<DeviceTensor> *op_context);
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_MANAGER_ACTOR_H_
|
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "runtime/device/device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using DeviceTensor = mindspore::device::DeviceAddress;
|
||||
using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;
|
||||
|
||||
// The device tensor mainly includes address ptr, size and reference count,
|
||||
// which represents the basic data structure of kernel launch and transfers between actors.
|
||||
// Some device tensors (such as weights and value nodes of graph) are fixed addresses and persistent,
|
||||
// so they are more suitable for store and can be obtained when they are used by actor.
|
||||
class DeviceTensorStore {
|
||||
public:
|
||||
DeviceTensorStore() = default;
|
||||
virtual ~DeviceTensorStore() = default;
|
||||
|
||||
static DeviceTensorStore &GetInstance() {
|
||||
static DeviceTensorStore instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Support value modifiable, so use the way of array subscript directly.
|
||||
void Insert(void *key, DeviceTensorPtr value) { device_tensors_[key] = value; }
|
||||
|
||||
void Remove(void *key) {
|
||||
auto iter = device_tensors_.find(key);
|
||||
if (iter != device_tensors_.end()) {
|
||||
(void)device_tensors_.erase(iter);
|
||||
}
|
||||
}
|
||||
|
||||
DeviceTensorPtr Fetch(void *key) const {
|
||||
auto iter = device_tensors_.find(key);
|
||||
if (iter != device_tensors_.end()) {
|
||||
return iter->second;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// The data storage of device tensor, key is anfNode ptr.
|
||||
std::unordered_map<void *, DeviceTensorPtr> device_tensors_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
|
@ -0,0 +1,111 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include "runtime/framework/actor/data_source_actor.h"
|
||||
#include "runtime/framework/actor/loop_count_actor.h"
|
||||
#include "runtime/framework/actor/kernel_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
|
||||
enum class GraphExecutionStrategy {
|
||||
// The actor running is triggered only by data.
|
||||
kPipeline,
|
||||
// The actor running need be triggered by control in addition.
|
||||
kStep
|
||||
};
|
||||
|
||||
// The actor set generated by graph transformer is the execution unit of actor runtime.
|
||||
// It includes data source actor, kernel actor, loop count actor.
|
||||
// The data source actor is used to obtain data and process them into device tensors,
|
||||
// and then send them to kernel actor. The kernel actor is used to receive the device tensors to luanch kernel.
|
||||
// Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered
|
||||
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
|
||||
// and decide whether to loop execution by loop count.
|
||||
struct ActorSet {
|
||||
std::vector<DataSourceActorPtr> data_source_actors_;
|
||||
std::vector<KernelActorPtr> kernel_actors_;
|
||||
// No input kernel actors need be triggered specifically.
|
||||
std::vector<KernelActorPtr> no_input_kernel_actors_;
|
||||
LoopCountActorPtr loop_count_actor_{nullptr};
|
||||
};
|
||||
using ActorSetPtr = std::shared_ptr<ActorSet>;
|
||||
|
||||
class GraphScheduler {
|
||||
public:
|
||||
GraphScheduler() = default;
|
||||
virtual ~GraphScheduler() = default;
|
||||
|
||||
static GraphScheduler &GetInstance() {
|
||||
static GraphScheduler instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Transform graph to actor DAG, contains build and link.
|
||||
ActorSetPtr Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
|
||||
const std::vector<tensor::TensorPtr> *input_tensors = nullptr,
|
||||
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
|
||||
|
||||
// Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
|
||||
// will be supported in the future.
|
||||
void Schedule(const ActorSetPtr &actor_set);
|
||||
|
||||
// The processing entry of actors running.
|
||||
bool Run(const ActorSetPtr &actor_set);
|
||||
|
||||
private:
|
||||
// Transform the nodes of graph to actors.
|
||||
ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
|
||||
// Link actors to DAG through the edge connection of graph and graph execution strategy.
|
||||
void Link(ActorSetPtr actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
|
||||
|
||||
// The processing of actors build.
|
||||
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph);
|
||||
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
|
||||
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
|
||||
|
||||
// The processing of actors link.
|
||||
void LinkDataSourceActor(std::vector<DataSourceActorPtr> actors, const KernelGraphPtr &graph);
|
||||
void LinkKernelActor(std::vector<KernelActorPtr> actors, const KernelGraphPtr &graph,
|
||||
GraphExecutionStrategy strategy);
|
||||
void LinkLoopCountActor(LoopCountActorPtr actor, const KernelGraphPtr &graph);
|
||||
|
||||
// Persist device tensors of graph's some nodes(such as weights and value nodes).
|
||||
void PersistDeviceTensor(const KernelGraphPtr &graph);
|
||||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
|
||||
|
||||
std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actor_;
|
||||
std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
|
||||
|
||||
// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
|
||||
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_SCHEDULER_H_
|
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include "ir/tensor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
// Host tensor queue is used to store host tensors, and its data will be fetched by the host queue data source actor.
|
||||
class HostTensorQueue {
|
||||
public:
|
||||
HostTensorQueue() = default;
|
||||
virtual ~HostTensorQueue() = default;
|
||||
|
||||
void PushData(std::vector<TensorPtr> tensors) { buffers_.push(tensors); }
|
||||
|
||||
std::vector<TensorPtr> PullData() {
|
||||
if (buffers_.empty()) {
|
||||
std::vector<TensorPtr> empty_tensor;
|
||||
return empty_tensor;
|
||||
}
|
||||
auto tensors = buffers_.front();
|
||||
buffers_.pop();
|
||||
return tensors;
|
||||
}
|
||||
|
||||
private:
|
||||
std::queue<std::vector<TensorPtr>> buffers_;
|
||||
};
|
||||
|
||||
using HostTensorQueuePtr = std::shared_ptr<HostTensorQueue>;
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_HOST_QUEUE_STORE_H_
|
Loading…
Reference in new issue