implement pause in MapOp, added more to callback add ds_callback - Initial drop of Python DSCallback - Pybind DSCallback - Pybind DSCallback added callback to mapOp - de_pipeline DSCallback - de_pipeline DSCallback add test case, segfault for now fix seg fault - de_pipeline DSCallback remove 1 line update callback test case, now works use builder class for mapOp callback - de_pipeline DSCallback - de_pipeline DSCallback - de_pipeline DSCallback better test case minor fix add comments and minor clean ups get rid of nullptr in MapOp, use other flag instead fix a bug ParseMapOp only takes 1 callback - Added WaitedDSCallback refactor callback param fix test case incorrect number - added testing fix cpp test case - added testing - revert back lenet changes - cleanup test_callbacks.py - cleanup test_callbacks.py fix CI stage I fix CI stage II fix CI and update epoch counter - add validation - add more testing test_callbacks.py use random data op to do tests adjust when to call EpochBegin/End - add repeat with callback - addressing reviewers' comments - docstring and CI fixes - docstring and CI fixes - docstring and CI fixes - rebase with upstream/master fix cpp test case fix review comments addr review cmts, add test case pull/4632/head
parent
89cd465268
commit
78c1aa1d96
@ -0,0 +1,45 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Registers the PyDSCallback class with the Python module. The Python side constructs it from a
// step size (int32_t) and then installs the individual hook functions through the set_* methods.
PYBIND_REGISTER(PyDSCallback, 0, ([](const py::module *m) {
                  using PyDSCallbackClass = py::class_<PyDSCallback, std::shared_ptr<PyDSCallback>>;
                  (void)PyDSCallbackClass(*m, "PyDSCallback")
                    .def(py::init<int32_t>())
                    // hooks for the whole pipeline run
                    .def("set_begin", &PyDSCallback::setBegin)
                    .def("set_end", &PyDSCallback::setEnd)
                    // hooks around each epoch
                    .def("set_epoch_begin", &PyDSCallback::setEpochBegin)
                    .def("set_epoch_end", &PyDSCallback::setEpochEnd)
                    // hooks around each step
                    .def("set_step_begin", &PyDSCallback::setStepBegin)
                    .def("set_step_end", &PyDSCallback::setStepEnd);
                }));
|
||||
|
||||
// Registers CallbackParam with the Python module. The three counters are exposed read-only; note
// that the C++ field cur_epoch_step_num_ is published to Python as "cur_step_num_in_epoch".
PYBIND_REGISTER(CallbackParam, 0, ([](const py::module *m) {
                  using CallbackParamClass = py::class_<CallbackParam, std::shared_ptr<CallbackParam>>;
                  (void)CallbackParamClass(*m, "CallbackParam")
                    .def(py::init<int64_t, int64_t, int64_t>())
                    .def_readonly("cur_epoch_num", &CallbackParam::cur_epoch_num_)
                    .def_readonly("cur_step_num_in_epoch", &CallbackParam::cur_epoch_step_num_)
                    .def_readonly("cur_step_num", &CallbackParam::cur_step_num_);
                }));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,14 @@
|
||||
# Tag every source file below this directory with the MindData submodule id.
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)

# The callback manager is always built; the Python callback is only added for Python-enabled builds.
set(_CALLBACK_SRC_FILES callback_manager.cc)
if (ENABLE_PYTHON)
    list(APPEND _CALLBACK_SRC_FILES py_ds_callback.cc)
endif ()

add_library(callback OBJECT ${_CALLBACK_SRC_FILES})
|
@ -0,0 +1,160 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/callback/callback_manager.h"
|
||||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
void CallbackManager::AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks) {
|
||||
callbacks_.insert(callbacks_.end(), callbacks.begin(), callbacks.end());
|
||||
}
|
||||
|
||||
/// \brief Bind the manager to its DatasetOp and validate the registered callbacks.
/// \param[in] op the DatasetOp that owns this manager; must not be null
/// \return Status::OK() on success; an unexpected-error Status when op is null, when a registered
///     callback is null, or when a callback reports a step_size that is not greater than 0
Status CallbackManager::Init(std::shared_ptr<DatasetOp> op) {
  // Stay disabled until everything has been validated, so a failed Init never leaves the
  // manager in a state where the hook functions would run.
  enabled_ = false;
  RETURN_UNEXPECTED_IF_NULL(op);

  // error check for each of the callbacks
  for (const auto &cb : callbacks_) {
    // a null entry would otherwise be dereferenced here and again in every hook function
    RETURN_UNEXPECTED_IF_NULL(cb);
    CHECK_FAIL_RETURN_UNEXPECTED(cb->step_size() > 0, "callback step_size needs to be greater than 0.");
  }

  op_ = op;
  // turn the flag on only if at least one callback is set
  enabled_ = !callbacks_.empty();
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSBegin on every registered callback that reports IsBeginNeeded().
/// \param[in] cb_param run-time information forwarded unchanged to each callback
/// \return Status::OK() when the manager is disabled or no callback needs this hook; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::Begin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that asked for this hook, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    if (cb->IsBeginNeeded()) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSBegin(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSEpochBegin on every registered callback that reports IsEpochBeginNeeded().
/// \param[in] cb_param run-time information forwarded unchanged to each callback
/// \return Status::OK() when the manager is disabled or no callback needs this hook; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::EpochBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that asked for this hook, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    if (cb->IsEpochBeginNeeded()) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSEpochBegin(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSNStepBegin on every registered callback that wants it at the current step.
/// A callback is selected when IsNStepBeginNeeded() is true and (cur_epoch_step_num_ - 1) is a
/// multiple of its step_size().
/// \param[in] cb_param run-time information; cur_epoch_step_num_ drives the step_size filter
/// \return Status::OK() when the manager is disabled or no callback is selected; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::StepBegin(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that are due at this step, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    const bool due = cb->IsNStepBeginNeeded() && (cb_param.cur_epoch_step_num_ - 1) % cb->step_size() == 0;
    if (due) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSNStepBegin(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSEnd on every registered callback that reports IsEndNeeded().
/// \param[in] cb_param run-time information forwarded unchanged to each callback
/// \return Status::OK() when the manager is disabled or no callback needs this hook; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::End(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that asked for this hook, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    if (cb->IsEndNeeded()) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSEnd(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSEpochEnd on every registered callback that reports IsEpochEndNeeded().
/// \param[in] cb_param run-time information forwarded unchanged to each callback
/// \return Status::OK() when the manager is disabled or no callback needs this hook; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::EpochEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that asked for this hook, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    if (cb->IsEpochEndNeeded()) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSEpochEnd(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
/// \brief Invoke DSNStepEnd on every registered callback that wants it at the current step.
/// A callback is selected when IsNStepEndNeeded() is true and (cur_epoch_step_num_ - 1) is a
/// multiple of its step_size().
/// \param[in] cb_param run-time information; cur_epoch_step_num_ drives the step_size filter
/// \return Status::OK() when the manager is disabled or no callback is selected; otherwise the
///     result of pausing the op and running the hooks in registration order
Status CallbackManager::StepEnd(const CallbackParam &cb_param) {
  RETURN_OK_IF_TRUE(!enabled_);

  // Select the callbacks that are due at this step, keeping registration order.
  std::vector<DSCallback *> selected;
  for (const auto &cb : callbacks_) {
    const bool due = cb->IsNStepEndNeeded() && (cb_param.cur_epoch_step_num_ - 1) % cb->step_size() == 0;
    if (due) {
      selected.push_back(cb.get());
    }
  }
  // Nothing selected: return without pausing the op.
  RETURN_OK_IF_TRUE(selected.empty());

  RETURN_IF_NOT_OK(op_->PauseFromMaster());

  // Run the selected hooks.
  for (DSCallback *cb : selected) {
    RETURN_IF_NOT_OK(cb->DSNStepEnd(cb_param));
  }
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,79 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// forward declare to avoid cyclic include of dataset_op.h
|
||||
class DatasetOp;
|
||||
|
||||
/// This class manages all the callbacks that are associated with a single DatasetOp. For now, only MapOp supports this.
|
||||
class CallbackManager {
|
||||
public:
|
||||
/// CallbackManager default constructor. Init needs to be called before using the created instance.
|
||||
CallbackManager() : enabled_(false) {}
|
||||
|
||||
/// \brief
|
||||
/// \param [in] callbacks list of callbacks to perform
|
||||
void AddCallbacks(std::vector<std::shared_ptr<DSCallback>> callbacks);
|
||||
|
||||
/// \brief DatasetOp needs to call Init if it wishes to use callback, Init will set enabled_ to true
|
||||
/// \param[in] op, this pointer is used for Callback Manager to Pause Worker threads
|
||||
/// \return Status
|
||||
Status Init(std::shared_ptr<DatasetOp> op);
|
||||
|
||||
/// \brief callback function called at the start of the first row
|
||||
/// \return Status
|
||||
Status Begin(const CallbackParam &);
|
||||
|
||||
/// \brief callback function called at the start of each epoch
|
||||
/// \return Status
|
||||
Status EpochBegin(const CallbackParam &);
|
||||
|
||||
/// \brief callback function called at the start of each row
|
||||
/// \return Status
|
||||
Status StepBegin(const CallbackParam &);
|
||||
|
||||
/// \brief callback function called after the last row is processed
|
||||
/// \return Status
|
||||
Status End(const CallbackParam &);
|
||||
|
||||
/// \brief callback function called at the end of each epoch
|
||||
/// \return Status
|
||||
Status EpochEnd(const CallbackParam &);
|
||||
|
||||
/// \brief callback function called at the the end of each row
|
||||
/// \return Status
|
||||
Status StepEnd(const CallbackParam &);
|
||||
|
||||
private:
|
||||
bool enabled_; // flag to enable callback, if false, all functions would return immediately
|
||||
std::shared_ptr<DatasetOp> op_; // back pointer to DatasetOp, each DatasetOp has only 1 CallbackManager
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_MANAGER_H
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// CallbackParam carries run-time counters from a DatasetOp to a user defined callback.
/// This is a prototype for now, more fields will be added
class CallbackParam {
 public:
  /// \brief Capture the three counters; they cannot be modified afterwards.
  /// \param epoch_num current epoch
  /// \param cur_epoch_step step number within the current epoch
  /// \param total_step_num step number counted since the first row
  CallbackParam(int64_t epoch_num, int64_t cur_epoch_step, int64_t total_step_num)
      : cur_epoch_num_{epoch_num}, cur_epoch_step_num_{cur_epoch_step}, cur_step_num_{total_step_num} {}

  // Public constant fields, kept public for easy access and consistency with python cb_param.
  // The names and orders are consistent with batchInfo.

  // current epoch
  const int64_t cur_epoch_num_;
  // step number of the current epoch
  const int64_t cur_epoch_step_num_;
  // step number since the first row
  const int64_t cur_step_num_;
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CALLBACK_PARAM_H
|
@ -0,0 +1,100 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/callback/callback_param.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DSCallback {
|
||||
public:
|
||||
/// \brief constructor of DSCallback, this is the base class for all front end specific callbacks
|
||||
/// \param step_size number of steps to call DSNStepBegin()
|
||||
explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
|
||||
|
||||
/// \brief actual callback function for begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSBegin(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSEpochBegin(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSNStepBegin(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief actual callback function for end, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSEnd(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief actual callback function epoch_end begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSEpochEnd(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief actual callback function for step_end, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
virtual Status DSNStepEnd(const CallbackParam &cb_param) = 0;
|
||||
|
||||
/// \brief predicate function, whether begin callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsBeginNeeded() = 0;
|
||||
|
||||
/// \brief predicate function, whether epoch_begin callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsEpochBeginNeeded() = 0;
|
||||
|
||||
/// \brief predicate function, whether step_begin callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsNStepBeginNeeded() = 0;
|
||||
|
||||
/// \brief predicate function, whether end callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsEndNeeded() = 0;
|
||||
|
||||
/// \brief predicate function, whether epoch_end callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsEpochEndNeeded() = 0;
|
||||
|
||||
/// \brief predicate function, whether step_end callback is needed
|
||||
/// \return bool
|
||||
virtual bool IsNStepEndNeeded() = 0;
|
||||
|
||||
/// \brief getter
|
||||
/// \return step_size
|
||||
int32_t step_size() const { return step_size_; }
|
||||
|
||||
protected:
|
||||
int32_t step_size_; // step begin/end will be called every step_size_
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_DS_CALLBACK_H
|
@ -0,0 +1,86 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/callback/callback_manager.h"
|
||||
#include "minddata/dataset/callback/py_ds_callback.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Each DS* hook below forwards to ExecutePyfunc with the Python function stored for that hook.

/// \brief run the stored begin function
Status PyDSCallback::DSBegin(const CallbackParam &cb_param) { return ExecutePyfunc(begin_func_, cb_param); }

/// \brief run the stored epoch_begin function
Status PyDSCallback::DSEpochBegin(const CallbackParam &cb_param) {
  return ExecutePyfunc(epoch_begin_func_, cb_param);
}

/// \brief run the stored step_begin function
Status PyDSCallback::DSNStepBegin(const CallbackParam &cb_param) {
  return ExecutePyfunc(step_begin_func_, cb_param);
}

/// \brief run the stored end function
Status PyDSCallback::DSEnd(const CallbackParam &cb_param) { return ExecutePyfunc(end_func_, cb_param); }

/// \brief run the stored epoch_end function
Status PyDSCallback::DSEpochEnd(const CallbackParam &cb_param) {
  return ExecutePyfunc(epoch_end_func_, cb_param);
}

/// \brief run the stored step_end function
Status PyDSCallback::DSNStepEnd(const CallbackParam &cb_param) {
  return ExecutePyfunc(step_end_func_, cb_param);
}
|
||||
|
||||
bool PyDSCallback::IsBeginNeeded() { return begin_needed_; }
|
||||
bool PyDSCallback::IsEpochBeginNeeded() { return epoch_begin_needed_; }
|
||||
bool PyDSCallback::IsNStepBeginNeeded() { return step_begin_needed_; }
|
||||
bool PyDSCallback::IsNStepEndNeeded() { return step_end_needed_; }
|
||||
bool PyDSCallback::IsEpochEndNeeded() { return epoch_end_needed_; }
|
||||
bool PyDSCallback::IsEndNeeded() { return end_needed_; }
|
||||
|
||||
/// \brief helper function to acquire GIL then execute a pyfunc
/// \param f the python function to call
/// \param cb_param callback parameter passed to f as its only argument
/// \return Status::OK() when the call completes; kPythonInterpreterFailure when the interpreter is
///     finalized; kPyFuncException carrying the Python error text when f raises
Status PyDSCallback::ExecutePyfunc(py::function f, const CallbackParam &cb_param) {
  {
    // Acquire Python GIL
    py::gil_scoped_acquire gil_acquire;
    if (Py_IsInitialized() == 0) {
      return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
    }
    try {
      f(cb_param);
    } catch (const py::error_already_set &e) {
      // A Python exception raised by the user callback is reported as a Status instead of
      // escaping as a C++ exception. It is handled here, inside the scope that holds the GIL.
      return Status(StatusCode::kPyFuncException, e.what());
    }
  }
  return Status::OK();
}
|
||||
void PyDSCallback::setBegin(py::function f) {
|
||||
begin_func_ = f;
|
||||
begin_needed_ = true;
|
||||
}
|
||||
void PyDSCallback::setEnd(py::function f) {
|
||||
end_func_ = f;
|
||||
end_needed_ = true;
|
||||
}
|
||||
void PyDSCallback::setEpochBegin(py::function f) {
|
||||
epoch_begin_func_ = f;
|
||||
epoch_begin_needed_ = true;
|
||||
}
|
||||
void PyDSCallback::setEpochEnd(py::function f) {
|
||||
epoch_end_func_ = f;
|
||||
epoch_end_needed_ = true;
|
||||
}
|
||||
void PyDSCallback::setStepBegin(py::function f) {
|
||||
step_begin_func_ = f;
|
||||
step_begin_needed_ = true;
|
||||
}
|
||||
void PyDSCallback::setStepEnd(py::function f) {
|
||||
step_end_func_ = f;
|
||||
step_end_needed_ = true;
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,130 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/callback/ds_callback.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
class PyDSCallback : public DSCallback {
|
||||
public:
|
||||
/// \brief constructor for PyDSCallback. This callback is for python front end
|
||||
explicit PyDSCallback(int32_t step_size = 1)
|
||||
: DSCallback(step_size),
|
||||
begin_needed_(false),
|
||||
epoch_begin_needed_(false),
|
||||
step_begin_needed_(false),
|
||||
end_needed_(false),
|
||||
epoch_end_needed_(false),
|
||||
step_end_needed_(false) {}
|
||||
|
||||
void setBegin(py::function f);
|
||||
void setEnd(py::function f);
|
||||
void setEpochBegin(py::function f);
|
||||
void setEpochEnd(py::function f);
|
||||
void setStepBegin(py::function f);
|
||||
void setStepEnd(py::function f);
|
||||
|
||||
/// \brief actual callback function for begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSBegin(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief actual callback function for epoch_begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSEpochBegin(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief actual callback function for step_begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSNStepBegin(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief actual callback function for end, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSEnd(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief actual callback function epoch_end begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSEpochEnd(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief actual callback function for step_end, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
/// \return Status
|
||||
Status DSNStepEnd(const CallbackParam &cb_param) override;
|
||||
|
||||
/// \brief predicate function, whether begin callback is needed
|
||||
/// \return bool
|
||||
bool IsBeginNeeded() override;
|
||||
|
||||
/// \brief predicate function, whether epoch_begin callback is needed
|
||||
/// \return bool
|
||||
bool IsEpochBeginNeeded() override;
|
||||
|
||||
/// \brief predicate function, whether step_begin callback is needed
|
||||
/// \return bool
|
||||
bool IsNStepBeginNeeded() override;
|
||||
|
||||
/// \brief predicate function, whether end callback is needed
|
||||
/// \return bool
|
||||
bool IsEndNeeded() override;
|
||||
|
||||
/// \brief predicate function, whether epoch_end callback is needed
|
||||
/// \return bool
|
||||
bool IsEpochEndNeeded() override;
|
||||
|
||||
/// \brief predicate function, whether step_end callback is needed
|
||||
/// \return bool
|
||||
bool IsNStepEndNeeded() override;
|
||||
|
||||
/// \brief helper function to acquire GIL then execute a pyfunc
|
||||
/// \param f the python function
|
||||
/// \param cb_param
|
||||
/// \return Status
|
||||
static Status ExecutePyfunc(py::function f, const CallbackParam &cb_param);
|
||||
|
||||
private:
|
||||
py::function begin_func_;
|
||||
py::function epoch_begin_func_;
|
||||
py::function step_begin_func_;
|
||||
py::function end_func_;
|
||||
py::function epoch_end_func_;
|
||||
py::function step_end_func_;
|
||||
|
||||
bool begin_needed_;
|
||||
bool epoch_begin_needed_;
|
||||
bool step_begin_needed_;
|
||||
bool end_needed_;
|
||||
bool epoch_end_needed_;
|
||||
bool step_end_needed_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
@ -0,0 +1,18 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""init file for python callback"""
|
||||
from .ds_callback import DSCallback, WaitedDSCallback
|
||||
|
||||
__all__ = ["DSCallback", "WaitedDSCallback"]
|
@ -0,0 +1,232 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Python callback class
|
||||
"""
|
||||
import threading
|
||||
from mindspore._c_dataengine import PyDSCallback
|
||||
from mindspore.train.callback import Callback
|
||||
from .validators import check_callback
|
||||
|
||||
|
||||
class DSCallback:
    """
    Abstract base class used to build a dataset callback class.

    A subclass overrides one or more of the ``ds_*`` methods. Only the overridden
    methods are registered on the runtime object built by `create_runtime_obj`.

    Args:
        step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1).

    Examples:
        >>> class PrintInfo(DSCallback):
        >>>     def ds_epoch_end(self, ds_run_context):
        >>>         print(ds_run_context.cur_epoch_num)
        >>>         print(ds_run_context.cur_step_num)
        >>>
        >>> data = data.map(operations=op, callbacks=PrintInfo())
    """

    @check_callback
    def __init__(self, step_size=1):
        self.step_size = step_size

    def ds_begin(self, ds_run_context):
        """
        Called before the data pipeline is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
        """
        Called before a new epoch is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
        """
        Called after an epoch is finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
        Called before n steps are started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
        Called after n steps are finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_end(self, ds_run_context):
        """
        Called after the data pipeline is finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user.

        Returns: _c_dataengine.PyDSCallback

        Raises:
            AttributeError: If the class does not override any of the callback methods.
        """
        c_cb = PyDSCallback(self.step_size)

        # Pair each overridable method with the setter that installs it on the runtime object.
        hooks = (
            ("ds_begin", c_cb.set_begin),
            ("ds_epoch_begin", c_cb.set_epoch_begin),
            ("ds_epoch_end", c_cb.set_epoch_end),
            ("ds_step_begin", c_cb.set_step_begin),
            ("ds_step_end", c_cb.set_step_end),
            ("ds_end", c_cb.set_end),
        )

        at_least_one = False
        for method_name, setter in hooks:
            # A method counts as user defined only when the class attribute differs from the base one.
            if getattr(self.__class__, method_name) != getattr(DSCallback, method_name):
                setter(getattr(self, method_name))
                at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")

        return c_cb
|
||||
|
||||
|
||||
class WaitedDSCallback(Callback, DSCallback):
    """
    Abstract base class used to build a dataset callback class that is synchronized with the training callback.

    This class can be used to execute a user defined logic right after the previous step or epoch.
    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

    Examples:
        >>> my_cb = MyWaitedCallback(32)
        >>> data = data.map(operations=AugOp(), callbacks=my_cb)
        >>> data = data.batch(32)
        >>> # define the model
        >>> model.train(epochs, data, callbacks=[my_cb])

    Args:
        step_size: the number of rows in each step.
            Usually the step size will be equal to the batch size (Default=1)
    """

    def __init__(self, step_size=1):
        super().__init__()
        self.step_size = step_size
        # Set by step_end (training side), waited on and cleared by ds_step_begin (dataset side).
        self.step_event = threading.Event()
        self.step_run_context = None

        # Set by epoch_end (training side), waited on and cleared by ds_epoch_begin (dataset side).
        self.epoch_event = threading.Event()
        self.epoch_run_context = None

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset epoch is started and after the previous training epoch is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous epoch.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def sync_step_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset step is started and after the previous training step is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous step.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def epoch_end(self, run_context):
        """
        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Publish the context first, then signal. The event is deliberately left set:
        # clearing it here could erase the signal before ds_epoch_begin starts waiting,
        # which would block the dataset thread forever. The waiter clears it instead.
        self.epoch_run_context = run_context
        self.epoch_event.set()

    def ds_epoch_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.

        Args:
            ds_run_context: Include some information of the pipeline.
        """
        # The first epoch has no preceding training epoch to wait for.
        if ds_run_context.cur_epoch_num > 1:
            # Returns immediately if epoch_end already fired, otherwise blocks until it does.
            self.epoch_event.wait()
            # Re-arm the event for the next epoch.
            self.epoch_event.clear()
            # epoch_end stores the context before setting the event, so it is available here.
            self.sync_epoch_begin(self.epoch_run_context, ds_run_context)

    def step_end(self, run_context):
        """
        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Publish the context first, then signal. The event is deliberately left set:
        # clearing it here could erase the signal before ds_step_begin starts waiting,
        # which would block the dataset thread forever. The waiter clears it instead.
        self.step_run_context = run_context
        self.step_event.set()

    def ds_step_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.

        Args:
            ds_run_context: Include some information of the pipeline.
        """
        # Steps up to and including step_size have no preceding training step to wait for.
        if ds_run_context.cur_step_num > self.step_size:
            # Returns immediately if step_end already fired, otherwise blocks until it does.
            self.step_event.wait()
            # Re-arm the event for the next step.
            self.step_event.clear()
            # step_end stores the context before setting the event, so it is available here.
            self.sync_step_begin(self.step_run_context, ds_run_context)

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.

        The internal ds_step_begin / ds_epoch_begin wrappers are registered only when the
        corresponding sync_* method has been overridden by the subclass.

        Returns:
            _c_dataengine.PyDSCallback

        Raises:
            AttributeError: If neither sync_step_begin nor sync_epoch_begin is overridden.
        """
        c_cb = PyDSCallback(self.step_size)
        at_least_one = False

        if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
            c_cb.set_step_begin(self.ds_step_begin)
            at_least_one = True

        if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
            c_cb.set_epoch_begin(self.ds_epoch_begin)
            at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")

        return c_cb
|
@ -0,0 +1,34 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""
|
||||
Built-in validators.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from ..core.validator_helpers import parse_user_args, check_pos_int32
|
||||
|
||||
|
||||
def check_callback(method):
    """Decorator that validates the arguments passed to the DSCallback constructor."""

    @wraps(method)
    def validated(self, *args, **kwargs):
        # The parsed positional list is expected to hold exactly one entry: step_size.
        parsed_args, _ = parse_user_args(method, *args, **kwargs)
        [step_size] = parsed_args
        check_pos_int32(step_size, "step_size")

        # Forward the caller's original arguments to the wrapped method.
        return method(self, *args, **kwargs)

    return validated
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue