dynamic shape check

pull/8578/head
wilfChen 4 years ago
parent edcb0cd86b
commit 2291b7f2e6

@@ -251,7 +251,6 @@ class GpuKernel : public KernelMod {
device::DynamicKernelPtr dynamic_kernel_;
};
} // namespace kernel
} // namespace mindspore
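For context on this first hunk: GpuKernel carries a device::DynamicKernelPtr member so shape-dependent state can be refreshed before each launch. Below is a minimal standalone sketch of that ownership pattern; DynamicKernel, UpdateArgs, and the float-pointer Launch signature are illustrative stand-ins, not the real MindSpore API.

#include <memory>
#include <vector>

// Illustrative stand-ins, not the real MindSpore types.
struct DynamicKernel {
  virtual void UpdateArgs() = 0;  // re-resolve shapes/args before launch
  virtual ~DynamicKernel() = default;
};
using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;

class KernelMod {
 public:
  virtual bool Launch(const std::vector<float *> &inputs, std::vector<float *> *outputs) = 0;
  virtual ~KernelMod() = default;
};

class GpuKernel : public KernelMod {
 public:
  bool Launch(const std::vector<float *> &inputs, std::vector<float *> *outputs) override {
    // Dynamic-shape kernels refresh their argument metadata on every
    // launch; static-shape kernels leave dynamic_kernel_ null and skip it.
    if (dynamic_kernel_ != nullptr) {
      dynamic_kernel_->UpdateArgs();
    }
    // ... enqueue the actual device kernel using inputs/outputs here ...
    (void)inputs;
    (void)outputs;
    return true;
  }

 private:
  DynamicKernelPtr dynamic_kernel_;  // null for static-shape kernels
};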

@@ -110,44 +110,57 @@ void ReplaceOutput(const FuncGraphPtr &graph, const AnfNodePtr &bn_grad, const A
manager->Replace(relu_grad, bn_add_relu_grad_output[kBNAddReluGradOutputNum - 1]);
return;
}
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
VectorRef batch_norm_grad =
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
return batch_norm_grad;
}
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
bool PatternCheck(const FuncGraphPtr &graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
return false;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto relu_users = GetRealNodeUsedList(graph, relu_grad);
if (relu_users->size() != 2) {
return nullptr;
return false;
}
// process pattern as Relu(TensorAdd(BN#0, BN#1))
auto tuple_getitem = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 5);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (!utils::isa<CNodePtr>(tuple_getitem) || AnfAlgo::GetCNodeName(tuple_getitem) != prim::kPrimTupleGetItem->name()) {
return nullptr;
return false;
}
auto forward_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_getitem), 0);
if (AnfAlgo::GetCNodeName(forward_node) != kFusedBatchNormExWithAddAndActivation) {
return false;
}
return true;
}
} // namespace
const BaseRef BatchNormAddReluGradFusion::DefinePattern() const {
VectorRef relu_grad = VectorRef({prim::kPrimReluGrad, dy_, y_});
VectorRef batch_norm_grad =
VectorRef({prim::kPrimFusedBatchNormGradEx, relu_grad, x_, scale_, save_mean_, save_var_, reserve_});
return batch_norm_grad;
}
const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (!PatternCheck(graph, node)) {
return nullptr;
}
auto relu_grad = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), 0);
MS_EXCEPTION_IF_NULL(relu_grad);
auto dy = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 0);
MS_EXCEPTION_IF_NULL(dy);
auto y = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(relu_grad), 1);
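This hunk pulls all of the match validation out of Process and into a bool-returning PatternCheck helper in the anonymous namespace, so Process reduces to a single early exit before the rewrite work. A minimal sketch of that refactor shape, with a hypothetical Node/NodePtr standing in for the AnfNodePtr graph types:

#include <memory>

// Hypothetical stand-ins for the AnfNodePtr graph types.
struct Node {
  bool format_ok = false;
  bool users_ok = false;
};
using NodePtr = std::shared_ptr<Node>;

namespace {
// Every structural check lives in one predicate that answers yes/no,
// instead of scattering `return nullptr;` through the rewrite logic.
bool PatternCheck(const NodePtr &node) {
  if (node == nullptr || !node->format_ok) {
    return false;
  }
  if (!node->users_ok) {
    return false;
  }
  return true;
}
}  // namespace

NodePtr Process(const NodePtr &node) {
  if (!PatternCheck(node)) {
    return nullptr;  // pattern does not match; leave the graph unchanged
  }
  // ... build and return the fused replacement node here ...
  return node;
}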

@@ -1432,7 +1432,7 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
}
bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
}
bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
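The last hunk only changes the lambda parameter from int to int64_t so it matches the element type of shape->shape(): a negative entry marks an unknown dimension, and narrowing each element to int would be lossy for large dimension values. A standalone sketch of the same predicate over a plain std::vector<int64_t> (the -1 convention is the usual dynamic-dimension marker, assumed here for illustration):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// A minimal sketch, assuming shape dims are stored as int64_t and a
// negative entry (e.g. -1) marks an unknown, i.e. dynamic, dimension.
bool IsDynamicShape(const std::vector<int64_t> &shape) {
  // The lambda parameter must be int64_t to match the element type;
  // taking plain int would silently narrow large dimension values.
  return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

int main() {
  std::cout << IsDynamicShape({32, -1, 224}) << '\n';  // 1: dynamic
  std::cout << IsDynamicShape({32, 3, 224}) << '\n';   // 0: static
}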
