MindData profiler infrastructure

Branch: pull/1978/head
Author: Alexey Shevlyakov, committed by Junhan Hu
Parent: 86253c342a
Commit: 4e333a2f22

@@ -62,6 +62,7 @@ add_dependencies(engine-datasetops-source core)
 add_dependencies(engine-datasetops-source-sampler core)
 add_dependencies(engine-datasetops core)
 add_dependencies(engine-opt core)
+add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)
 add_dependencies(text core)
@@ -81,6 +82,7 @@ set(submodules
         $<TARGET_OBJECTS:engine-datasetops-source>
         $<TARGET_OBJECTS:engine-datasetops-source-sampler>
         $<TARGET_OBJECTS:engine-gnn>
+        $<TARGET_OBJECTS:engine-perf>
         $<TARGET_OBJECTS:engine-datasetops>
         $<TARGET_OBJECTS:engine-opt>
         $<TARGET_OBJECTS:engine>

@@ -239,11 +239,13 @@ void bindTensor(py::module *m) {
     .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
     .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
     .def("set_seed", &ConfigManager::set_seed)
+    .def("set_monitor_sampling_interval", &ConfigManager::set_monitor_sampling_interval)
     .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
     .def("get_num_parallel_workers", &ConfigManager::num_parallel_workers)
     .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
     .def("get_op_connector_size", &ConfigManager::op_connector_size)
     .def("get_seed", &ConfigManager::seed)
+    .def("get_monitor_sampling_interval", &ConfigManager::monitor_sampling_interval)
     .def("load", [](ConfigManager &c, std::string s) { (void)c.LoadFile(s); });
   (void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor", py::buffer_protocol())

@@ -88,5 +88,7 @@ void ConfigManager::set_op_connector_size(int32_t connector_size) { op_connector
 uint32_t ConfigManager::seed() const { return seed_; }
 void ConfigManager::set_seed(uint32_t seed) { seed_ = seed; }
+
+void ConfigManager::set_monitor_sampling_interval(uint32_t interval) { monitor_sampling_interval_ = interval; }
 }  // namespace dataset
 }  // namespace mindspore

@@ -111,12 +111,21 @@ class ConfigManager {
   // @param seed - The default seed to use
   void set_seed(uint32_t seed);

+  // setter function
+  // @param interval - The setting to apply to the config
+  void set_monitor_sampling_interval(uint32_t interval);
+
+  // getter function
+  // @return The interval of monitor sampling
+  uint32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
+
  private:
   int32_t rows_per_buffer_{kCfgRowsPerBuffer};
   int32_t num_parallel_workers_{kCfgParallelWorkers};
   int32_t worker_connector_size_{kCfgWorkerConnectorSize};
   int32_t op_connector_size_{kCfgOpConnectorSize};
   uint32_t seed_{kCfgDefaultSeed};
+  uint32_t monitor_sampling_interval_{kCfgMonitorSamplingInterval};

   // Private helper function that takes a nlohmann json format and populates the settings
   // @param j - The json nlohmann json info

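The ConfigManager changes above give the profiler a configurable sampling interval: a default constant, a private member, a setter/getter pair, and matching Python bindings. A minimal sketch of exercising it from C++, assuming ConfigManager is default-constructible (as the member initializers suggest); the diff does not state the interval's unit:

    #include <iostream>
    #include "dataset/core/config_manager.h"

    void ConfigureMonitorInterval() {
      mindspore::dataset::ConfigManager cfg;
      // Starts at kCfgMonitorSamplingInterval (10); override for finer sampling.
      cfg.set_monitor_sampling_interval(20);
      std::cout << "monitor sampling interval: " << cfg.monitor_sampling_interval() << std::endl;
    }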
@@ -47,6 +47,7 @@ constexpr uint32_t kCfgParallelWorkers = 4;
 constexpr uint32_t kCfgWorkerConnectorSize = 16;
 constexpr uint32_t kCfgOpConnectorSize = 16;
 constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
+constexpr uint32_t kCfgMonitorSamplingInterval = 10;

 // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
 constexpr uint8_t kCVInvalidType = 255;

@@ -1,6 +1,7 @@
 add_subdirectory(datasetops)
 add_subdirectory(opt)
 add_subdirectory(gnn)
+add_subdirectory(perf)
 if (ENABLE_TDTQUE)
     add_subdirectory(tdt)
 endif ()
@@ -16,7 +17,7 @@ add_library(engine OBJECT
 target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
 if (ENABLE_TDTQUE)
-    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn)
+    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
 else()
-    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn)
+    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
 endif ()

@@ -83,7 +83,14 @@ Status IteratorBase::FetchNextTensorRow(TensorRow *out_row) {
 }

 // Constructor of the DatasetIterator
-DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree) : IteratorBase(), root_(exe_tree->root()) {}
+DatasetIterator::DatasetIterator(std::shared_ptr<ExecutionTree> exe_tree)
+    : IteratorBase(), root_(exe_tree->root()), tracing_(nullptr), cur_batch_num_(0), cur_connector_size_(0) {
+  std::shared_ptr<Tracing> node;
+  Status s = exe_tree->GetProfilingManager()->GetTracingNode(kDatasetIteratorTracingName, &node);
+  if (s.IsOk()) {
+    tracing_ = std::dynamic_pointer_cast<DatasetIteratorTracing>(node);
+  }
+}

 DatasetIterator::~DatasetIterator() = default;

@@ -101,6 +108,10 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
   // Check if we need to get a new DataBuffer to iterate.
   if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
+    if (tracing_ != nullptr) {
+      cur_connector_size_ = root_->ConnectorSize();
+      cur_connector_capacity_ = root_->ConnectorCapacity();
+    }
     RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));

     // Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
@@ -121,6 +132,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
       }
       eof_handled_ = true;
       curr_buffer_.reset();  // explicitly free the eof buffer
+      // Set tree to Finished state
+      root_->Tree()->SetFinished();
       return Status::OK();
     }
@@ -131,13 +144,18 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
       // flow of an eof up the pipeline by itself.
       eof_handled_ = true;
       curr_buffer_.reset();  // explicitly free the eof buffer
+      // Set tree to Finished state
+      root_->Tree()->SetFinished();
       return Status::OK();
     }
   }

   // If we got this far, now it's time to pop that next row for return to caller
   RETURN_IF_NOT_OK(curr_buffer_->PopRow(out_row));
+  if (tracing_ != nullptr) {
+    cur_batch_num_++;
+    tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_);
+  }
   return Status::OK();
 }

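The iterator now records one CONNECTOR_DEPTH sample per popped row via tracing_->Record(type, extra_info, batch_num, value). The Tracing classes themselves live in the new dataset/engine/perf/ files, which this diff does not show; below is a hypothetical sketch of such a node, based on the "type subtype batchnum value" record format documented in the comment removed from device_queue_op.cc:

    #include <cstdint>
    #include <string>
    #include <vector>

    // Hypothetical tracing node: buffers one whitespace-separated record per
    // sample; the real implementation presumably flushes these to a file when
    // the execution tree finishes.
    class TracingSketch {
     public:
      void Record(int32_t type, int32_t extra_info, int32_t batch_num, int32_t value) {
        records_.push_back(std::to_string(type) + " " + std::to_string(extra_info) + " " +
                           std::to_string(batch_num) + " " + std::to_string(value));
      }

     private:
      std::vector<std::string> records_;
    };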
@@ -24,6 +24,7 @@
 #include "dataset/core/tensor.h"
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/perf/dataset_iterator_tracing.h"

 namespace mindspore {
 namespace dataset {
@@ -109,6 +110,10 @@ class DatasetIterator : public IteratorBase {
  private:
   std::shared_ptr<DatasetOp> root_;  // saves the root of the executionTree
   TensorRow device_queue_row_;
+  std::shared_ptr<DatasetIteratorTracing> tracing_;  // trace profiling data
+  int32_t cur_batch_num_;           // current batch number, used for profiling
+  int32_t cur_connector_size_;      // current connector size of root op, used for profiling
+  int32_t cur_connector_capacity_;  // current connector capacity of root op, used for profiling
 };

 // The ChildIterator derived class is for fetching rows from intermediate nodes of execution tree.

@@ -189,6 +189,10 @@ class BatchOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "BatchOp"; }
+
  private:
   // Worker thread for doing the memcpy of batch
   // @param int32_t param workerId

@@ -81,6 +81,10 @@ class ConcatOp : public PipelineOp {
   // before providing their own implementations.
   Status PrepareNodePostAction() override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ConcatOp"; }
+
  private:
   Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

@@ -38,6 +38,7 @@ DatasetOp::DatasetOp(int32_t op_connector_size)
       tree_(nullptr),
       state_(OpState::kDeOpIdle),
       op_ctrl_flags_(kDeOpNone),
+      out_connector_(nullptr),
       first_fetch_(true) {
   // The operator starts out with an invalid operator id. The only way to
   // get it out of invalid state is to assign the operator to an execution tree.

@@ -51,7 +51,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   };

   // Flags that control operator runtime behaviours
-  enum OpState { kDeOpRunning = 0, kDeOpIdle = 1 };
+  enum OpState { kDeOpRunning = 0, kDeOpIdle = 1, kDeOpTerminated };

   // Constructor
   // @param op_connector_size - The size for the output connector of this operator.
@@ -213,11 +213,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // Getter function
   // @return connector size of current op
-  virtual int32_t ConnectorSize() const { return out_connector_->size(); }
+  int32_t ConnectorSize() const {
+    if (!inlined()) {
+      return out_connector_->size();
+    }
+    // Return -1 for inlined op
+    return -1;
+  }

   // Getter function
   // @return connector capacity of current op
-  virtual int32_t ConnectorCapacity() const { return out_connector_->capacity(); }
+  int32_t ConnectorCapacity() const {
+    if (!inlined()) {
+      return out_connector_->capacity();
+    }
+    // Return -1 for inlined op
+    return -1;
+  }

   // Getter function
   // @return connector size of child op
@@ -228,7 +240,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   int32_t ChildOpConnectorCapacity(int32_t child_index = 0) const { return child_[child_index]->ConnectorCapacity(); }

   // Children Getter
-  // @return Vector or Children
+  // @return Vector of Children
   std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }

   // Base method for NodePass visit.
@@ -237,6 +249,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   // @return Status of the node visit
   virtual Status Accept(NodePass *p, bool *modified);

+  // Op name getter
+  // @return Name of the current Op
+  virtual std::string Name() const { return "DatasetOp"; }
+
+  // Execution Tree getter
+  // @return Pointer to the ExecutionTree the current op belongs to, no ownership
+  ExecutionTree *Tree() { return tree_; }
+
  protected:
   // Adds a parent operator to this operator
   // @notes External callers do not have access to this function.

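Together, Name(), Tree(), and the connector getters give profiling code a uniform way to walk the execution tree and label its samples. A hedged sketch of that use (the traversal and logging are illustrative only, not part of this commit):

    #include <iostream>
    #include <memory>
    #include "dataset/engine/datasetops/dataset_op.h"

    // Recursively print each op's name with its connector occupancy/capacity.
    // Inlined ops report -1 for both, so a monitor can skip them.
    void DumpConnectorDepths(const std::shared_ptr<mindspore::dataset::DatasetOp> &op) {
      std::cout << op->Name() << ": " << op->ConnectorSize() << "/" << op->ConnectorCapacity() << "\n";
      for (const auto &child : op->Children()) {
        DumpConnectorDepths(child);
      }
    }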
@@ -13,25 +13,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include "dataset/engine/datasetops/device_queue_op.h"
 #include <iomanip>
 #include <iostream>
 #include <memory>

 #include "dataset/core/config_manager.h"
 #include "dataset/core/global_context.h"
-#include "dataset/engine/datasetops/device_queue_op.h"
 #include "dataset/engine/data_buffer.h"
 #include "dataset/engine/dataset_iterator.h"
+#include "dataset/engine/opt/pass.h"
+#include "dataset/engine/perf/profiling.h"
+#include "dataset/engine/perf/device_queue_tracing.h"
 #include "dataset/util/status.h"
 #include "dataset/util/task_manager.h"
-#include "dataset/engine/opt/pass.h"
-#include "dataset/util/profiling.h"

 namespace mindspore {
 namespace dataset {
-#define DEVICE_QUEUE_PROFILING_DATA(type, subtype, batch_num, value) \
-  std::to_string(type) + " " + std::to_string(subtype) + " " + std::to_string(batch_num) + " " + std::to_string(value)
-
 DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
                              int32_t op_connector_size, int64_t num_batch)
     : PipelineOp(op_connector_size),
@@ -101,22 +99,16 @@ Status DeviceQueueOp::SendDataToAscend() {
   MS_LOG(INFO) << "Device queue, sending data to Ascend.";
   int64_t total_batch = 0;
   bool is_break_loop = false;
-  double batch_start_time, tdt_start_time, end_time;
+  double batch_start_time, end_time;
   int32_t batch_cost, tdt_cost;
   int32_t connector_size = 0;
   int32_t connector_capacity;
-  std::shared_ptr<Profiling> profiling_node;
-  bool isProfilingEnable = ProfilingManager::GetInstance().IsProfilingEnable();
+  std::shared_ptr<DeviceQueueTracing> profiling_node;
+  bool isProfilingEnable = tree_->GetProfilingManager()->IsProfilingEnable();
   if (isProfilingEnable) {
-    std::string file_name = "critical_point_profiling";
-    // Here can determine performance bottleneck is in pipeline or in tdt.
-    // Content format of this file: "type subtype batchnum value"
-    // type: 0: time, 1: queue depth
-    // subtype: 0: pipeline time, 1: push tdt time, 2: all time
-    // batchnum: batch number
-    // value: value of time(ms) or queue depth
-    profiling_node = std::make_shared<Profiling>(file_name, device_id_);
-    RETURN_IF_NOT_OK(ProfilingManager::GetInstance().RegisterProfilingNode(&profiling_node));
+    std::shared_ptr<Tracing> node;
+    RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
+    profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
     batch_start_time = ProfilingTime::GetCurMilliSecond();
     connector_capacity = ChildOpConnectorCapacity();
   }
@@ -129,29 +121,23 @@ Status DeviceQueueOp::SendDataToAscend() {
       TensorRow currRow;
       for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) {
         RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
-        if (isProfilingEnable) {
-          tdt_start_time = ProfilingTime::GetCurMilliSecond();
-        }
-        auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_);
+        auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
         if (status == TdtStatus::FAILED) {
           return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
         }
         if (isProfilingEnable) {
           end_time = ProfilingTime::GetCurMilliSecond();
-          tdt_cost = (int32_t)(end_time - tdt_start_time);
           // record push tdt time
-          profiling_node->Record(DEVICE_QUEUE_PROFILING_DATA(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost));
+          profiling_node->Record(TIME, TDT_PUSH_TIME, total_batch + 1, tdt_cost);
           batch_cost = (int32_t)(end_time - batch_start_time);
           // record batch time
-          profiling_node->Record(DEVICE_QUEUE_PROFILING_DATA(TIME, BATCH_TIME, total_batch + 1, batch_cost));
+          profiling_node->Record(TIME, BATCH_TIME, total_batch + 1, batch_cost);
           // record pipeline time
-          profiling_node->Record(
-            DEVICE_QUEUE_PROFILING_DATA(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost));
+          profiling_node->Record(TIME, PIPELINE_TIME, total_batch + 1, batch_cost - tdt_cost);
           batch_start_time = end_time;
           // record connector depth
-          profiling_node->Record(
-            DEVICE_QUEUE_PROFILING_DATA(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size));
+          profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
         }
         total_batch++;
         if (num_batch_ > 0 && total_batch == num_batch_) {
@@ -171,9 +157,7 @@ Status DeviceQueueOp::SendDataToAscend() {
     RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
   }

-  if (isProfilingEnable) {
-    profiling_node->SaveToFile();
-  }
+  tree_->SetFinished();

   MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
   return Status::OK();

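The bookkeeping above splits each batch's wall time in two: hostPush now measures tdt_cost internally, and PIPELINE_TIME is whatever remains of batch_cost. A trivial illustration of that arithmetic (names mirror the diff; values are made up):

    #include <cstdint>

    // If a batch takes 25 ms end to end (batch_cost) and hostPush reports 9 ms
    // spent inside tdt (tdt_cost), PIPELINE_TIME is recorded as the remaining 16 ms.
    int32_t PipelineCost(int32_t batch_cost, int32_t tdt_cost) { return batch_cost - tdt_cost; }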
@@ -140,6 +140,10 @@ class DeviceQueueOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "DeviceQueueOp"; }
+
  private:
   // Name: checkExceptions(DataBuffer);
   // Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp

@@ -127,6 +127,10 @@ class FilterOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "FilterOp"; }
+
  private:
   // predicate_func python callable which returns a boolean value.
   py::function predicate_func_;

@@ -177,6 +177,10 @@ class MapOp : public ParallelOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "MapOp"; }
+
  private:
   // Local queues where worker threads can pop from.
   // Popping directly from the Connector can block if the previous designated threads haven't popped.

@@ -107,6 +107,10 @@ class ProjectOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ProjectOp"; }
+
  private:
   std::vector<std::string> columns_to_project_;
   std::vector<int32_t> projected_column_indices_;

@@ -116,6 +116,10 @@ class RenameOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "RenameOp"; }
+
  protected:
   // Rename core functionality
   Status RenameColumns();

@@ -124,9 +124,9 @@ class RepeatOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

-  virtual int32_t ConnectorSize() const { return child_[0]->ConnectorSize(); }
-
-  virtual int32_t ConnectorCapacity() const { return child_[0]->ConnectorCapacity(); }
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "RepeatOp"; }

  private:
   int32_t max_repeats_;  // The number of repeats that the user requested

@@ -161,6 +161,10 @@ class ShuffleOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ShuffleOp"; }
+
  private:
   // Private function to add a new row to the shuffle buffer.
   // @return Status - The error code return

@@ -80,6 +80,10 @@ class SkipOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "SkipOp"; }
+
  private:
   int32_t max_skips_;   // The number of skips that the user requested
   int32_t skip_count_;  // A counter for the current number of executed skips

@@ -169,6 +169,10 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @return Status - The error code return
   Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const { return "CelebAOp"; }
+
  private:
   // Called first when function is called
   // @return

@@ -155,6 +155,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @return
   static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "CifarOp"; }
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return

@@ -127,6 +127,10 @@ class GeneratorOp : public PipelineOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "GeneratorOp"; }
+
  private:
   py::function generator_function_;
   std::vector<std::string> column_names_;

@@ -210,6 +210,10 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
   // @return - Status of the node visit.
   Status Accept(NodePass *p, bool *modified) override;

+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "ImageFolderOp"; }
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return

Some files were not shown because too many files have changed in this diff.