|
|
|
@ -104,19 +104,11 @@ public:
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
size_t getFilterHeight(const TensorShape& filter) const {
|
|
|
|
|
if (filter.ndims() == 5) {
|
|
|
|
|
return filter[3];
|
|
|
|
|
} else {
|
|
|
|
|
return filter[2];
|
|
|
|
|
}
|
|
|
|
|
filter[filter.ndims() - 2];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t getFilterWidth(const TensorShape& filter) const {
|
|
|
|
|
if (filter.ndims() == 5) {
|
|
|
|
|
return filter[4];
|
|
|
|
|
} else {
|
|
|
|
|
return filter[3];
|
|
|
|
|
}
|
|
|
|
|
filter[filter.ndims() - 1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> strides_;
|
|
|
|
|