!3024 Decode + RandomCropAndResize fusion within MapOp
Merge pull request !3024 from Alexey_Shevlyakov/random_crop_decode_resize_fusionpull/3024/MERGE
commit
530d46eb47
@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/optional/tensor_op_fusion_pass.h"
|
||||
#include "dataset/kernels/image/decode_op.h"
|
||||
#include "dataset/engine/datasetops/map_op.h"
|
||||
#include "dataset/kernels/image/random_crop_decode_resize_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status TensorOpFusionPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
|
||||
// Most primitive pattern: DecodeOp immediately followed by RandomCropAndResizeOp
|
||||
// Abstract into a more general member function that can find any pattern, expressed
|
||||
// by regular expressions, for instance.
|
||||
// Add a list of optimisation policies. For now, just this lambda
|
||||
auto FindPattern = [](auto &tfuncs) {
|
||||
auto it =
|
||||
std::find_if(tfuncs.begin(), tfuncs.end(), [](const auto &tf) -> bool { return tf->Name() == kDecodeOp; });
|
||||
auto next = it + 1;
|
||||
if (it != tfuncs.end() && next != tfuncs.end() && (*next)->Name() == kRandomCropAndResizeOp) {
|
||||
return it;
|
||||
} else {
|
||||
return tfuncs.end();
|
||||
}
|
||||
};
|
||||
|
||||
auto &tfuncs = node->TFuncs();
|
||||
auto it = FindPattern(tfuncs);
|
||||
if (it != tfuncs.end()) {
|
||||
auto next = it + 1;
|
||||
auto op = static_cast<RandomCropAndResizeOp *>(next->get());
|
||||
*it = std::static_pointer_cast<TensorOp>(std::make_shared<RandomCropDecodeResizeOp>(*op));
|
||||
tfuncs.erase(next);
|
||||
}
|
||||
if (modified != nullptr) {
|
||||
*modified = true;
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("modified is nullptr");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_TENSOR_OP_FUSION_PASS_H_
|
||||
#define DATASET_TENSOR_OP_FUSION_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include "dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class TensorOpFusionPass tensor_op_fusion_pass.h
|
||||
/// \brief And optional optimization pass identifying and fusing
|
||||
/// tensor ops within MapOp
|
||||
class TensorOpFusionPass : public NodePass {
|
||||
/// \brief Identifies and fuses tensor ops within MapOp
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] *modified indicates whether the node has been visited
|
||||
/// \return Status The error code return
|
||||
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_TENSOR_OP_FUSION_PASS_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue