reshape_2d used from ddim.h

test=develop
revert-13872-fix2
Sylwester Fraczek 6 years ago
parent 55d6950a1a
commit 50c5e9b0c6

@ -44,18 +44,6 @@ namespace ir {
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
// reshape to two dimensions {A, B * C * ...}
DDim make_dims_2d(DDim dims) {
auto dims_count = dims.size();
PADDLE_ENFORCE_GT(dims_count, 0);
int size2 = 1;
for (int i = 1; i < dims_count; i++) {
size2 *= dims[i];
}
return make_ddim({dims[0], size2});
}
void recompute_bias_and_weights(const Scope* scope,
ir::Node* conv_weight, //
const ir::Node& bn_scale, //
@ -104,7 +92,7 @@ void recompute_bias_and_weights(const Scope* scope,
// Re-compute weight of conv2d from BN
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = make_dims_2d(weights_shape);
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],

Loading…
Cancel
Save