Implemented Callback for Dataset

implement pause in MapOp, added more to callback

add ds_callback

- Initial drop of Python DSCallback

- Pybind DSCallback

- Pybind DSCallback

added callback to mapOp

- de_pipeline DSCallback

- de_pipeline DSCallback

add test case, segfault for now

fix seg fault

- de_pipeline DSCallback

remove 1 line

update callback test case, now works

use builder class for mapOp callback

- de_pipeline DSCallback

- de_pipeline DSCallback

- de_pipeline DSCallback

better test case

minor fix

add comments and minor clean ups

get rid of nullptr in MapOp, use other flag instead

fix a bug where ParseMapOp only takes 1 callback

- Added WaitedDSCallback

refactor callback param

fix test case incorrect number

- added testing

fix cpp test case

- added testing

- revert lenet changes

- cleanup test_callbacks.py

- cleanup test_callbacks.py

fix CI stage I

fix CI stage II

fix CI and update epoch counter

- add validation
- add more testing in test_callbacks.py

use random data op to do tests

adjust when to call EpochBegin/End

- add repeat with callback

- addressing reviewers' comments

- docstring and CI fixes

- docstring and CI fixes

- docstring and CI fixes

- rebase with upstream/master

fix cpp test case

fix review comments

address review comments, add test case
pull/4632/head
Zirui Wu 5 years ago
parent 89cd465268
commit 78c1aa1d96
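For context, a minimal sketch of the user-facing API this change introduces, assembled from the DSCallback docstring and the map() signature added below; the source dataset, column name, and lambda are illustrative placeholders rather than part of the change:

import mindspore.dataset as ds
from mindspore.dataset.callback import DSCallback

class PrintInfo(DSCallback):
    # report progress at the end of every epoch
    def ds_epoch_end(self, ds_run_context):
        print(ds_run_context.cur_epoch_num, ds_run_context.cur_step_num)

# hypothetical pipeline: any source dataset and per-row operation would do
data = ds.NumpySlicesDataset([1, 2, 3], column_names=["col1"])
data = data.map(input_columns=["col1"], operations=[lambda x: x], callbacks=PrintInfo())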

@ -58,6 +58,7 @@ add_subdirectory(kernels)
add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
add_subdirectory(callback)
######################################################################
add_dependencies(utils core)
add_dependencies(kernels-image core)
@ -74,6 +75,7 @@ add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
add_dependencies(callback core)
add_dependencies(text core)
add_dependencies(text-kernels core)
add_dependencies(cpp-API core)
@ -87,6 +89,7 @@ endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
$<TARGET_OBJECTS:callback>
$<TARGET_OBJECTS:utils>
$<TARGET_OBJECTS:kernels>
$<TARGET_OBJECTS:kernels-image>
@ -135,14 +138,14 @@ endif()
target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar)
if (${CMAKE_SYSTEM_NAME} MATCHES "Windows")
if (ENABLE_PYTHON)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} mindspore::protobuf ${SECUREC_LIBRARY})
else()
target_link_libraries(_c_dataengine PRIVATE mindspore::protobuf ${SECUREC_LIBRARY})
endif()
else()
set(ICU_LIB mindspore::icuuc mindspore::icudata mindspore::icui18n)
if (ENABLE_PYTHON)
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module -ldl mindspore::protobuf ${SECUREC_LIBRARY})
else()
target_link_libraries(_c_dataengine PRIVATE -ldl mindspore::protobuf ${SECUREC_LIBRARY})
endif()

@ -7,6 +7,7 @@ if (ENABLE_PYTHON)
python/bindings.cc
python/bindings/dataset/engine/cache/bindings.cc
python/bindings/dataset/core/bindings.cc
python/bindings/dataset/callback/bindings.cc
python/bindings/dataset/kernels/data/bindings.cc
python/bindings/dataset/kernels/bindings.cc
python/bindings/dataset/engine/datasetops/bindings.cc

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/callback/ds_callback.h"
namespace mindspore {
namespace dataset {
PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
(void)py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>(*m, "PyDSCallback")
.def(py::init<int32_t>())
.def("set_begin", &PyDSCallback::setBegin)
.def("set_end", &PyDSCallback::setEnd)
.def("set_epoch_begin", &PyDSCallback::setEpochBegin)
.def("set_epoch_end", &PyDSCallback::setEpochEnd)
.def("set_step_begin", &PyDSCallback::setStepBegin)
.def("set_step_end", &PyDSCallback::setStepEnd);
}));
PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
(void)py::class_<CallbackParam, std::shared_ptr<CallbackParam>>(*m, "CallbackParam")
.def(py::init<int64_t, int64_t, int64_t>())
.def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
.def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
.def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
}));
} // namespace dataset
} // namespace mindspore
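The CallbackParam registered above is the object a Python callback method receives as its ds_run_context argument; a small sketch of reading its read-only fields (the callback subclass itself is illustrative):

from mindspore.dataset.callback import DSCallback

class StepLogger(DSCallback):
    def ds_step_begin(self, ds_run_context):
        # the three fields exposed by the CallbackParam pybind registration above
        print("epoch:", ds_run_context.cur_epoch_num)
        print("step within epoch:", ds_run_context.cur_step_num_in_epoch)
        print("global step:", ds_run_context.cur_step_num)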

@ -20,6 +20,7 @@
#include <map>
#include "utils/ms_utils.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/dataset_iterator.h"
@ -738,8 +739,13 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "callbacks") {
std::vector<std::shared_ptr<DSCallback>> callbacks;
std::transform(value.begin(), value.end(), std::back_inserter(callbacks),
[](py::handle cb) { return cb.cast<std::shared_ptr<PyDSCallback>>(); });
(void)map_builder.AddCallbacks(callbacks);
} else {
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
RETURN_STATUS_UNEXPECTED("Error in parsing MapOp: Unhandled key: " + key);
}
}
}

@ -0,0 +1,14 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
if (ENABLE_PYTHON)
add_library(callback OBJECT
callback_manager.cc
py_ds_callback.cc
)
else ()
add_library(callback OBJECT
callback_manager.cc
)
endif ()

@ -0,0 +1,160 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
namespace mindspore {
namespace dataset {
void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
}
Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
RETURN_UNEXPECTED_IF_NULL(op);
op_ = op;
// turn the flag on if callback is set
enabled_ = !callbacks_.empty();
// error check for each of the callbacks
for (auto &cb : callbacks_) {
CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
}
return Status::OK();
}
Status CallbackManager::Begin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochBeginNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_begin is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepBegin(cb_param));
}
return Status::OK();
}
Status CallbackManager::End(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEnd(cb_param));
}
return Status::OK();
}
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsEpochEndNeeded()) callback_inds.push_back(ind);
}
// return Status::OK() if no epoch_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSEpochEnd(cb_param));
}
return Status::OK();
}
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
RETURN_OK_IF_TRUE(!enabled_);
std::vector<size_t> callback_inds;
// go through all callback functions to see if each function is needed
for (size_t ind = 0; ind < callbacks_.size(); ind++) {
if (callbacks_[ind]->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % callbacks_[ind]->step_size() == 0)
callback_inds.push_back(ind);
}
// return Status::OK() if no step_end is needed
RETURN_OK_IF_TRUE(callback_inds.empty());
RETURN_IF_NOT_OK(op_->PauseFromMaster());
// Now do the actual callback
for (size_t ind : callback_inds) {
RETURN_IF_NOT_OK(callbacks_[ind]->DSNStepEnd(cb_param));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,79 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
#include <memory>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// forward declare to avoid cyclic include of dataset_op.h
class DatasetOp;
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
class CallbackManager {
public:
/// CallbackManager default constructor. Init needs to be called before using the created instance.
CallbackManager() : enabled_(false) {}
/// \brief add a list of callbacks to this manager
/// \param [in] callbacks list of callbacks to perform
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
/// \brief DatasetOp needs to call Init if it wishes to use callbacks; Init will set enabled_ to true
/// \param[in] op this pointer is used by the CallbackManager to pause worker threads
/// \return Status
Status Init(std::shared_ptr<DatasetOp> op);
/// \brief callback function called at the start of the first row
/// \return Status
Status Begin(const CallbackParam &);
/// \brief callback function called at the start of each epoch
/// \return Status
Status EpochBegin(const CallbackParam &);
/// \brief callback function called at the start of each row
/// \return Status
Status StepBegin(const CallbackParam &);
/// \brief callback function called after the last row is processed
/// \return Status
Status End(const CallbackParam &);
/// \brief callback function called at the end of each epoch
/// \return Status
Status EpochEnd(const CallbackParam &);
/// \brief callback function called at the end of each row
/// \return Status
Status StepEnd(const CallbackParam &);
private:
bool enabled_; // flag to enable callback, if false, all functions would return immediately
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
#include <nlohmann/json.hpp>
namespace mindspore {
namespace dataset {
/// CallbackParam is the object a DatasetOp uses to pass run-time information to user-defined functions.
/// This is a prototype for now; more fields will be added
class CallbackParam {
public:
CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
: cur_epoch_num_(epoch_num), cur_epoch_step_num_(cur_epoch_step), cur_step_num_(total_step_num) {}
// these are constant public fields for easy access and consistency with python cb_param
// the names and orders are consistent with batchInfo
const int64_t cur_epoch_num_; // current epoch
const int64_t cur_epoch_step_num_; // step number of the current epoch
const int64_t cur_step_num_; // step number since the first row
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H

@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class DSCallback {
public:
/// \brief constructor of DSCallback, this is the base class for all front-end-specific callbacks
/// \param step_size number of steps between calls to DSNStepBegin()
explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
/// \brief actual callback function for begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEnd(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for epoch_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;
/// \brief actual callback function for step_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;
/// \brief predicate function, whether begin callback is needed
/// \return bool
virtual bool IsBeginNeeded() = 0;
/// \brief predicate function, whether epoch_begin callback is needed
/// \return bool
virtual bool IsEpochBeginNeeded() = 0;
/// \brief predicate function, whether step_begin callback is needed
/// \return bool
virtual bool IsNStepBeginNeeded() = 0;
/// \brief predicate function, whether end callback is needed
/// \return bool
virtual bool IsEndNeeded() = 0;
/// \brief predicate function, whether epoch_end callback is needed
/// \return bool
virtual bool IsEpochEndNeeded() = 0;
/// \brief predicate function, whether step_end callback is needed
/// \return bool
virtual bool IsNStepEndNeeded() = 0;
/// \brief getter
/// \return step_size
int32_t step_size() const { return step_size_; }
protected:
int32_t step_size_; // step begin/end will be called every step_size_
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H

@ -0,0 +1,86 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/callback/py_ds_callback.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status PyDSCallback::DSBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(begin_func_, cb_param);
}
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(epoch_begin_func_, cb_param);
}
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(step_begin_func_, cb_param);
}
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return PyDSCallback::ExecutePyfunc(end_func_, cb_param); }
Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(epoch_end_func_, cb_param);
}
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
return PyDSCallback::ExecutePyfunc(step_end_func_, cb_param);
}
bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
bool PyDSCallback::IsEndNeeded() { return end_needed_; }
Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
{
// Acquire Python GIL
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
f(cb_param);
}
return Status::OK();
}
void PyDSCallback::setBegin(py::function f) {
begin_func_ = f;
begin_needed_ = true;
}
void PyDSCallback::setEnd(py::function f) {
end_func_ = f;
end_needed_ = true;
}
void PyDSCallback::setEpochBegin(py::function f) {
epoch_begin_func_ = f;
epoch_begin_needed_ = true;
}
void PyDSCallback::setEpochEnd(py::function f) {
epoch_end_func_ = f;
epoch_end_needed_ = true;
}
void PyDSCallback::setStepBegin(py::function f) {
step_begin_func_ = f;
step_begin_needed_ = true;
}
void PyDSCallback::setStepEnd(py::function f) {
step_end_func_ = f;
step_end_needed_ = true;
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,130 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/util/status.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace dataset {
namespace py = pybind11;
class PyDSCallback : public DSCallback {
public:
/// \brief constructor for PyDSCallback. This callback is for python front end
explicit PyDSCallback(int32_t step_size = 1)
: DSCallback(step_size),
begin_needed_(false),
epoch_begin_needed_(false),
step_begin_needed_(false),
end_needed_(false),
epoch_end_needed_(false),
step_end_needed_(false) {}
void setBegin(py::function f);
void setEnd(py::function f);
void setEpochBegin(py::function f);
void setEpochEnd(py::function f);
void setStepBegin(py::function f);
void setStepEnd(py::function f);
/// \brief actual callback function for begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEpochBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSNStepBegin(const CallbackParam &cb_param) override;
/// \brief actual callback function for end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEnd(const CallbackParam &cb_param) override;
/// \brief actual callback function for epoch_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSEpochEnd(const CallbackParam &cb_param) override;
/// \brief actual callback function for step_end, needs to be overridden in the derived class
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
/// \return Status
Status DSNStepEnd(const CallbackParam &cb_param) override;
/// \brief predicate function, whether begin callback is needed
/// \return bool
bool IsBeginNeeded() override;
/// \brief predicate function, whether epoch_begin callback is needed
/// \return bool
bool IsEpochBeginNeeded() override;
/// \brief predicate function, whether step_begin callback is needed
/// \return bool
bool IsNStepBeginNeeded() override;
/// \brief predicate function, whether end callback is needed
/// \return bool
bool IsEndNeeded() override;
/// \brief predicate function, whether epoch_end callback is needed
/// \return bool
bool IsEpochEndNeeded() override;
/// \brief predicate function, whether step_end callback is needed
/// \return bool
bool IsNStepEndNeeded() override;
/// \brief helper function to acquire GIL then execute a pyfunc
/// \param f the python function
/// \param cb_param callback parameter passed to the python function
/// \return Status
static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);
private:
py::function begin_func_;
py::function epoch_begin_func_;
py::function step_begin_func_;
py::function end_func_;
py::function epoch_end_func_;
py::function step_end_func_;
bool begin_needed_;
bool epoch_begin_needed_;
bool step_begin_needed_;
bool end_needed_;
bool epoch_end_needed_;
bool step_end_needed_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H

@ -21,6 +21,8 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/callback/callback_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/util/status.h"
@ -358,6 +360,14 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return boolean returns true if it's last iteration
bool IsLastIteration() { return op_total_repeats_ == op_current_repeats_ + 1; }
/// This function is only intended to be called by CallbackManager within the master thread of a ParallelOp.
/// The expected behavior is that, when invoked, this function blocks until all the workers have finished their
/// remaining work and gone to sleep. Since all ParallelOps use a QueueList to sync with the master, workers
/// automatically wait on the QueueList when they are done. Hence, for now, an Unpause() function is not
/// needed. Only ParallelOp needs to override this function.
/// \return Status
virtual Status PauseFromMaster() { return Status::OK(); }
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
@ -394,6 +404,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
std::unique_ptr<DbConnector> out_connector_; // Output Connector
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp
private:
/// Sets the operator id.

@ -15,25 +15,23 @@
*/
#include <algorithm>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/callback/callback_param.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/datasetops/map_op/cpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/gpu_map_job.h"
#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
@ -58,6 +56,7 @@ Status MapOp::Builder::Build(std::shared_ptr<MapOp> *ptr) {
RETURN_IF_NOT_OK(sanityCheck());
*ptr = std::make_shared<MapOp>(std::move(build_in_col_names_), std::move(build_out_col_names_),
std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_);
(*ptr)->callback_manager_.AddCallbacks(std::move(builder_callbacks_));
return Status::OK();
}
@ -164,7 +163,10 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
Status MapOp::operator()() {
// Create and register the local queues.
local_queues_.Init(num_workers_, oc_queue_size_);
// init callback
RETURN_IF_NOT_OK(callback_manager_.Init(shared_from_this()));
Status rc = local_queues_.Register(tree_->AllTasks());
RETURN_IF_NOT_OK(master_pause_wp_.Register(tree_->AllTasks()));
if (rc.IsError()) {
TaskManager::FindMe()->Post();
return rc;
@ -175,28 +177,51 @@ Status MapOp::operator()() {
// Synchronize with TaskManager
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(rc);
// num_buffers received, including eoe, num_epoch, num_step of current epoch
int64_t num_buf = 0, ep_step = 0, total_step = 0;
RETURN_IF_NOT_OK(callback_manager_.Begin(CallbackParam(0, ep_step, total_step)));
int64_t que_id = 0;
std::unique_ptr<DataBuffer> buff;
bool is_eof = false;
// Drain output connector of the previous op, generate jobs for worker threads, and distribute them via local queues
// Stop when all worker threads are finished (received EOF)
while (!is_eof) {
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
is_eof = buff->eof();
// Create an empty map worker job to be populated by a databuffer and map jobs
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>();
worker_job->databuffer = std::move(buff);
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
while (!buff->eof()) {
if (op_current_repeats_ % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
}
while (!buff->eoe()) {
ep_step++;
total_step++;
// Create an empty map worker job to be populated by a databuffer and map jobs
RETURN_IF_NOT_OK(callback_manager_.StepBegin(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
// Populate map worker job for a worker to execute
RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
// Populate map worker job for a worker to execute
RETURN_IF_NOT_OK(GenerateWorkerJob(&worker_job));
// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
RETURN_IF_NOT_OK(callback_manager_.StepEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
// Push map worker job to the corresponding worker's queue
RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(worker_job)));
que_id = (que_id + 1) % num_workers_;
}
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// send the eoe buffer to worker
// reset epoch_step when a new epoch is about to start
if ((op_current_repeats_ + 1) % op_num_repeats_per_epoch() == 0) {
RETURN_IF_NOT_OK(callback_manager_.EpochEnd(CallbackParam(op_current_epochs_ + 1, ep_step, total_step)));
ep_step = 0;
}
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
UpdateRepeatAndEpochCounter();
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buff, 0));
}
// the last eoe increments the eoe count by 1, but this shouldn't be reflected on End() callback
// RETURN_IF_NOT_OK(callback_manager_.End(CallbackParam(op_current_epochs_, ep_step, total_step)));
// handle eof logic
std::unique_ptr<MapWorkerJob> worker_job = std::make_unique<MapWorkerJob>(std::move(buff));
RETURN_IF_NOT_OK(local_queues_[num_buf++ % num_workers_]->Add(std::move(worker_job)));
return Status::OK();
}
@ -213,25 +238,19 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
// Fetch next data buffer and map job list
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
// Sanity check the databuffer.
// Special case: if there's more threads than buffers, some threads simply get the final control
// messages (eoe/eof), and so they will not perform the check.
if (!in_buffer->eoe() && !in_buffer->eof()) {
int32_t num_rows = in_buffer->NumRows();
int32_t num_cols = in_buffer->NumCols();
if (num_rows == 0 || num_cols == 0) {
RETURN_STATUS_UNEXPECTED("MapOp is getting an empty DataBuffer.");
}
}
// Now that init work is done, drop into the main fetching loop.
// Map op does not use child iterator, and it needs to manually handle eoe and eof's itself
// rather than use the base-class defaults.
while (true) {
// Handle EOE and EOF ourselves. Implicit eoe/eof handling in GetNextInput does not work
// with Performance Mode design.
if (in_buffer->eoe()) {
UpdateRepeatAndEpochCounter();
// handle the pause logic. Pause is triggered when a buffer with an id of -1, no special flag, and no rows is received
if (in_buffer->id() == -1 && in_buffer->buffer_flags() == DataBuffer::kDeBFlagNone && in_buffer->NumRows() == 0) {
// when a worker receives the signal from the master thread, it increments an atomic int
// the last worker to increment the counter wakes up the master thread
if (++num_workers_paused_ == num_workers_) master_pause_wp_.Set();
// this will block the worker until the master thread gives it new work
RETURN_IF_NOT_OK(FetchNextWork(worker_id, &in_buffer, &job_list));
continue;
} else if (in_buffer->eoe()) {
// Calling base class EoeReceived to forward eoe buffer.
RETURN_IF_NOT_OK(EoeReceived(worker_id));
// Fetch next data buffer and map job list
@ -243,6 +262,7 @@ Status MapOp::WorkerEntry(int32_t worker_id) {
break;
}
CHECK_FAIL_RETURN_UNEXPECTED(in_buffer->NumRows() * in_buffer->NumCols() != 0, "MapOp got an empty DataBuffer.");
std::unique_ptr<TensorQTable> new_tensor_table(std::make_unique<TensorQTable>());
// Perform the compute function of TensorOp(s) and store the result in new_tensor_table.
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), new_tensor_table.get(), job_list));
@ -281,9 +301,9 @@ Status MapOp::WorkerCompute(DataBuffer *in_buffer, TensorQTable *new_tensor_tabl
std::vector<TensorRow> result_table;
// Executing the list of jobs
for (size_t i = 0; i < job_list.size(); i++) {
// Executre MapJob.
// Execute MapJob.
RETURN_IF_NOT_OK(job_list[i]->Run(job_input_table, &result_table));
// Assign the pocessed data as an input for the next job processing, except for the last TensorOp in the list.
// Assign the processed data as an input for the next job processing, except for the last TensorOp in the list.
if (i + 1 < job_list.size()) {
job_input_table = std::move(result_table);
}
@ -428,5 +448,20 @@ Status MapOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<MapOp>(), modified);
}
Status MapOp::PauseFromMaster() {
// reset num_paused workers to 0
num_workers_paused_ = 0;
for (int32_t wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
// a special buffer (id=-1, empty, none flag) is used to signal that the worker needs to pause.
RETURN_IF_NOT_OK(local_queues_[wkr_id]->Add(
std::make_unique<MapWorkerJob>(std::make_unique<DataBuffer>(-1, DataBuffer::kDeBFlagNone))));
}
// wait until all workers are done processing their work in local_queue_
RETURN_IF_NOT_OK(master_pause_wp_.Wait());
// clear the WaitPost for the next Wait()
master_pause_wp_.Clear();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -16,15 +16,19 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_MAP_OP_H_
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/callback/ds_callback.h"
#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/engine/datasetops/map_op/map_job.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
@ -108,6 +112,13 @@ class MapOp : public ParallelOp {
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &AddCallbacks(const std::vector<std::shared_ptr<DSCallback>> &callbacks) {
builder_callbacks_.insert(builder_callbacks_.end(), callbacks.begin(), callbacks.end());
return *this;
}
// The builder "build" method creates the final object.
// @param ptr The shared_ptr to the new MapOp object
// @return Status
@ -116,6 +127,7 @@ class MapOp : public ParallelOp {
private:
std::vector<std::string> build_in_col_names_;
std::vector<std::string> build_out_col_names_;
std::vector<std::shared_ptr<DSCallback>> builder_callbacks_;
std::vector<std::shared_ptr<TensorOp>> build_tensor_funcs_;
int32_t build_num_workers_;
int32_t build_op_connector_size_;
@ -186,6 +198,7 @@ class MapOp : public ParallelOp {
// A unit of job for map worker thread.
// MapWorkerJob holds a list of MapJob where each MapJob can be a CpuMapJob, GpuMapJob or DvppMapJob.
struct MapWorkerJob {
explicit MapWorkerJob(std::unique_ptr<DataBuffer> db) : databuffer(std::move(db)) {}
std::vector<std::shared_ptr<MapJob>> jobs;
std::unique_ptr<DataBuffer> databuffer;
};
@ -215,6 +228,12 @@ class MapOp : public ParallelOp {
// Indices of the columns to process.
std::vector<size_t> to_process_indices_;
// wait post used to perform the pausing logic in MapOp
WaitPost master_pause_wp_;
// count number of workers that have signaled master
std::atomic_int num_workers_paused_;
// Private function for worker/thread to loop continuously. It comprises the main
// logic of MapOp: getting the data from previous Op, validating user specified column names,
// applying a list of TensorOps to each of the data, process the results and then
@ -247,6 +266,13 @@ class MapOp : public ParallelOp {
// Private function for initializing private variables such as in_columns_, out_columns_.
// @return - Status
Status InitPrivateVariable(std::unordered_map<std::string, int32_t> *col_name_id_map);
// This function should only be called from the master thread. It suspends the operation of all workers and
// has them wait on the QueueList. The master thread sends a token to each worker and then waits on a WaitPost.
// Upon receiving the suspension token from the master thread, each worker increments an atomic count; the last
// worker to do the increment wakes up the master.
// @return - Status
Status PauseFromMaster() override;
};
} // namespace dataset
} // namespace mindspore

@ -34,7 +34,7 @@ class Semaphore {
/// \brief Decrement the internal counter. Will be blocked if the value is 0.
/// \return Error code. Can get interrupt.
Status P();
/// \brief Increment the internal counter. Wakeup on of the watiers if any.
/// \brief Increment the internal counter. Wake up one of the waiters if any.
void V();
/// \brief Peek the internal value
/// \return The internal value

@ -59,6 +59,13 @@ namespace dataset {
} \
} while (false)
#define RETURN_OK_IF_TRUE(_condition) \
do { \
if (_condition) { \
return Status::OK(); \
} \
} while (false)
enum class StatusCode : char {
kOK = 0,
kOutOfMemory = 1,

@ -0,0 +1,18 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""init file for python callback"""
from .ds_callback import DSCallback, WaitedDSCallback
__all__ = ["DSCallback", "WaitedDSCallback"]

@ -0,0 +1,232 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Python callback class
"""
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
from .validators import check_callback
class DSCallback:
"""
Abstract base class used to build a dataset callback class.
Args:
step_size (int, optional): The number of steps between consecutive calls to step_begin and step_end (Default=1).
Examples:
>>> class PrintInfo(DSCallback):
>>> def ds_epoch_end(self, ds_run_context):
>>> print(ds_run_context.cur_epoch_num)
>>> print(ds_run_context.cur_step_num)
>>>
>>> data = data.map(operations=op, callbacks=PrintInfo())
"""
@check_callback
def __init__(self, step_size=1):
self.step_size = step_size
def ds_begin(self, ds_run_context):
"""
Called before the data pipeline is started.
Args:
ds_run_context (RunContext): Includes some information about the pipeline.
"""
def ds_epoch_begin(self, ds_run_context):
"""
Called before a new epoch is started.
Args:
ds_run_context (RunContext): Includes some information about the pipeline.
"""
def ds_epoch_end(self, ds_run_context):
"""
Called after an epoch is finished.
Args:
ds_run_context (RunContext): Includes some information about the pipeline.
"""
def ds_step_begin(self, ds_run_context):
"""
Called before n steps are started.
Args:
ds_run_context (RunContext): Includes some information about the pipeline.
"""
def ds_step_end(self, ds_run_context):
"""
Called after n steps are finished.
Args:
ds_run_context (RunContext): Includes some information about the pipeline.
"""
def create_runtime_obj(self):
"""
Creates a runtime (C++) object from the callback methods defined by the user.
Returns: _c_dataengine.PyDSCallback
"""
c_cb = PyDSCallback(self.step_size)
at_least_one = False
if self.__class__.ds_begin != DSCallback.ds_begin:
c_cb.set_begin(self.ds_begin)
at_least_one = True
if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
c_cb.set_epoch_begin(self.ds_epoch_begin)
at_least_one = True
if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
c_cb.set_epoch_end(self.ds_epoch_end)
at_least_one = True
if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
c_cb.set_step_begin(self.ds_step_begin)
at_least_one = True
if self.__class__.ds_step_end != DSCallback.ds_step_end:
c_cb.set_step_end(self.ds_step_end)
at_least_one = True
if not at_least_one:
raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
return c_cb
class WaitedDSCallback(Callback, DSCallback):
"""
Abstract base class used to build a dataset callback class that is synchronized with the training callback.
This class can be used to execute user-defined logic right after the previous step or epoch.
For example, an augmentation may need the loss from the previous training epoch to update some of its parameters.
Examples:
>>> my_cb = MyWaitedCallback(32)
>>> data = data.map(operations=AugOp(), callbacks=my_cb)
>>> data = data.batch(32)
>>> # define the model
>>> model.train(epochs, data, callbacks=[my_cb])
Args:
step_size: the number of rows in each step.
Usually the step size will be equal to the batch size (Default=1)
"""
def __init__(self, step_size=1):
super().__init__()
self.step_size = step_size
self.step_event = threading.Event()
self.step_run_context = None
self.epoch_event = threading.Event()
self.epoch_run_context = None
def sync_epoch_begin(self, train_run_context, ds_run_context):
"""
Called before a new dataset epoch is started and after the previous training epoch is ended.
Args:
train_run_context: Includes some information about the model with feedback from the previous epoch.
ds_run_context: Includes some information about the dataset pipeline.
"""
def sync_step_begin(self, train_run_context, ds_run_context):
"""
Called before a new dataset step is started and after the previous training step is ended.
Args:
train_run_context: Includes some information about the model with feedback from the previous step.
ds_run_context: Includes some information about the dataset pipeline.
"""
def epoch_end(self, run_context):
"""
Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
Args:
run_context: Includes some information about the model.
"""
self.epoch_run_context = run_context
self.epoch_event.set()
self.epoch_event.clear()
def ds_epoch_begin(self, ds_run_context):
"""
Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
Args:
ds_run_context: Includes some information about the pipeline.
"""
if ds_run_context.cur_epoch_num > 1:
if self.epoch_run_context is None:
self.epoch_event.wait()
self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
self.epoch_run_context = None
def step_end(self, run_context):
"""
Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
Args:
run_context: Includes some information about the model.
"""
self.step_run_context = run_context
self.step_event.set()
self.step_event.clear()
def ds_step_begin(self, ds_run_context):
"""
Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
Args:
ds_run_context: Includes some information about the pipeline.
"""
if ds_run_context.cur_step_num > self.step_size:
if self.step_run_context is None:
self.step_event.wait()
self.sync_step_begin(self.step_run_context, ds_run_context)
self.step_run_context = None
def create_runtime_obj(self):
"""
Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
Returns: _c_dataengine.PyDSCallback
"""
c_cb = PyDSCallback(self.step_size)
at_least_one = False
if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
c_cb.set_step_begin(self.ds_step_begin)
at_least_one = True
if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
c_cb.set_epoch_begin(self.ds_epoch_begin)
at_least_one = True
if not at_least_one:
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
return c_cb
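A sketch of how WaitedDSCallback is meant to be wired up, following the class docstring above; MyWaitedCallback is a hypothetical user subclass, and the pipeline/training lines are kept as an outline since AugOp, data, and model are placeholders from the docstring:

from mindspore.dataset.callback import WaitedDSCallback

class MyWaitedCallback(WaitedDSCallback):
    def sync_epoch_begin(self, train_run_context, ds_run_context):
        # runs after the previous training epoch ends and before the next
        # dataset epoch begins; train_run_context carries model-side feedback
        pass  # e.g. update augmentation parameters from the last epoch's loss

my_cb = MyWaitedCallback(32)  # step_size usually equals the batch size
# data = data.map(operations=AugOp(), callbacks=my_cb)
# data = data.batch(32)
# the same object must also be passed to training so its epoch_end/step_end fire:
# model.train(epochs, data, callbacks=[my_cb])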

@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Built-in validators.
"""
from functools import wraps
from ..core.validator_helpers import parse_user_args, check_pos_int32
def check_callback(method):
"""check the input arguments of DSCallback."""
@wraps(method)
def new_method(self, *args, **kwargs):
[step_size], _ = parse_user_args(method, *args, **kwargs)
check_pos_int32(step_size, "step_size")
return method(self, *args, **kwargs)
return new_method
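check_callback only validates that step_size is a positive int32 via check_pos_int32; a sketch of the expected behaviour (the exact exception type and message come from the shared validator helpers and are assumed here):

from mindspore.dataset.callback import DSCallback

DSCallback(step_size=1)          # accepted
try:
    DSCallback(step_size=0)      # rejected by check_pos_int32 in the decorator above
except (ValueError, TypeError) as err:
    print(err)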

@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset,\
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, \
check_paddeddataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
from ..text.utils import DE_C_INTER_SENTENCEPIECE_MODE
@ -395,7 +395,7 @@ class Dataset:
@check_map
def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
"""
Apply each operation in operations to this dataset.
@ -438,6 +438,8 @@ class Dataset:
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None).
Returns:
MapDataset, dataset after mapping operation.
@ -552,7 +554,7 @@ class Dataset:
>>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order)
"""
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers,
python_multiprocessing, cache)
python_multiprocessing, cache, callbacks)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
@ -1548,6 +1550,7 @@ class DatasetOp(Dataset):
return self.children[0].get_class_indexing()
raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))
class BucketBatchByLengthDataset(DatasetOp):
"""
The result of applying BucketBatchByLength operator to the input dataset.
@ -1964,14 +1967,14 @@ class MapDataset(DatasetOp):
option could be beneficial if the python operation is computational heavy (default=False).
cache (DatasetCache, optional): Tensor cache to use. (default=None which means no cache is used).
The cache feature is under development and is not recommended.
callbacks (DSCallback, list[DSCallback], optional): list of Dataset callbacks to be called (Default=None).
Raises:
ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified.
"""
def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
num_parallel_workers=None, python_multiprocessing=False, cache=None):
num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None):
super().__init__(num_parallel_workers)
self.children.append(input_dataset)
if input_columns is not None and not isinstance(input_columns, list):
@ -1996,6 +1999,11 @@ class MapDataset(DatasetOp):
self.python_multiprocessing = python_multiprocessing
self.process_pool = None
if callbacks is not None and not isinstance(callbacks, list):
callbacks = [callbacks]
self.callbacks = callbacks
def get_args(self):
args = super().get_args()
args["input_columns"] = self.input_columns
@ -2003,6 +2011,9 @@ class MapDataset(DatasetOp):
args["output_columns"] = self.output_columns
args["columns_order"] = self.columns_order
args["cache"] = self.cache.cache_client if self.cache is not None else None
if self.callbacks is not None:
args["callbacks"] = [cb.create_runtime_obj() for cb in self.callbacks]
return args
def get_dataset_size(self):
@ -2034,6 +2045,7 @@ class MapDataset(DatasetOp):
new_op.cache = copy.deepcopy(self.cache, memodict)
new_op.operations = self.operations
new_op.dataset_size = self.dataset_size
new_op.callbacks = self.callbacks
return new_op
# Iterator bootstrap will be called on iterator construction.
@ -2393,7 +2405,6 @@ class ConcatDataset(DatasetOp):
self._children_start_end_index_[index][0] = cumulative_samples_nums
self._children_start_end_index_[index][1] = tem_value % sampler.num_shards
tem_sampler = copy.deepcopy(sampler)
tem_sampler.set_offset(cumulative_samples_nums)
child.sampler = tem_sampler
@ -2556,7 +2567,7 @@ class RangeDataset(MappableDataset):
def get_dataset_size(self):
if self.dataset_size is None:
self.dataset_size = math.ceil((self.stop - self.start)/self.step)
self.dataset_size = math.ceil((self.stop - self.start) / self.step)
return self.dataset_size
@ -3423,7 +3434,7 @@ class GeneratorDataset(MappableDataset):
if not self.num_shards:
self.dataset_size = len(self.source)
else:
self.dataset_size = math.ceil(len(self.source)/self.num_shards)
self.dataset_size = math.ceil(len(self.source) / self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
@ -5428,6 +5439,7 @@ class NumpySlicesDataset(GeneratorDataset):
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)
class _PaddedDataset:
"""
Mainly for combining false samples provided by users into a dataset.
@ -5435,6 +5447,7 @@ class _PaddedDataset:
Args:
padded_samples (list(dict)): the data provided by user to added to initial Dataset
"""
def __init__(self, padded_samples):
self.column_names = list(padded_samples[0].keys())
self.padded_samples = padded_samples
@ -5445,6 +5458,7 @@ class _PaddedDataset:
def __len__(self):
return len(self.padded_samples)
class PaddedDataset(GeneratorDataset):
"""
Create a dataset with fake data provided by user. Mainly used to add to the original data set
@ -5463,6 +5477,7 @@ class PaddedDataset(GeneratorDataset):
>>> data1 = [{'image': np.zeros(1, np.uint8)}, {'image': np.zeros(2, np.uint8)}]
>>> ds1 = ds.PaddedDataset(data1)
"""
@check_paddeddataset
def __init__(self, padded_samples):
dataset = _PaddedDataset(padded_samples)

@ -23,6 +23,7 @@ from functools import wraps
import numpy as np
from mindspore._c_expression import typing
from mindspore.dataset.callback import DSCallback
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
@ -31,6 +32,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
from . import datasets
from . import samplers
from . import cache_client
from .. import callback
def check_imagefolderdatasetv2(method):
@ -247,6 +249,7 @@ def check_celebadataset(method):
return new_method
def check_save(method):
"""A wrapper that wrap a parameter checker to the save op."""
@ -257,7 +260,7 @@ def check_save(method):
nreq_param_int = ['num_files']
nreq_param_str = ['file_name', 'file_type']
validate_dataset_param_value(nreq_param_int, param_dict, int)
if(param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
raise ValueError("num_files should between {} and {}.".format(1, 1000))
validate_dataset_param_value(nreq_param_str, param_dict, str)
if param_dict.get('file_type') != 'mindrecord':
@ -265,6 +268,8 @@ def check_save(method):
return method(self, *args, **kwargs)
return new_method
def check_minddataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(MindDataset)."""
@ -362,6 +367,7 @@ def check_generatordataset(method):
return new_method
def check_random_dataset(method):
"""A wrapper that wraps a parameter checker to the original Dataset(RandomDataset)."""
@ -545,7 +551,8 @@ def check_map(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache], _ = \
[input_columns, _, output_columns, columns_order, num_parallel_workers, python_multiprocessing, cache,
callbacks], _ = \
parse_user_args(method, *args, **kwargs)
nreq_param_columns = ['input_columns', 'output_columns']
@ -558,9 +565,17 @@ def check_map(method):
if cache is not None:
type_check(cache, (cache_client.DatasetCache,), "cache")
if callbacks is not None:
if isinstance(callbacks, (list, tuple)):
type_check_list(callbacks, (callback.DSCallback,), "callbacks")
else:
type_check(callbacks, (callback.DSCallback,), "callbacks")
for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
if param is not None:
check_columns(param, param_name)
if callbacks is not None:
type_check(callbacks, (list, DSCallback), "callbacks")
return method(self, *args, **kwargs)

@ -15,6 +15,7 @@ SET(DE_UT_SRCS
bounding_box_augment_op_test.cc
arena_test.cc
btree_test.cc
callback_test.cc
center_crop_op_test.cc
channel_swap_test.cc
circular_pool_test.cc

File diff suppressed because it is too large

File diff suppressed because it is too large