From 60759e85229802af5b9f005e5c49a814b485f391 Mon Sep 17 00:00:00 2001
From: huzhifeng <zhifeng.hu@huawei.com>
Date: Wed, 16 Sep 2020 15:17:47 +0800
Subject: [PATCH] add googlenet include top for hub

---
 model_zoo/official/cv/googlenet/src/googlenet.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/model_zoo/official/cv/googlenet/src/googlenet.py b/model_zoo/official/cv/googlenet/src/googlenet.py
index 78695f2d6c..2ccf395487 100644
--- a/model_zoo/official/cv/googlenet/src/googlenet.py
+++ b/model_zoo/official/cv/googlenet/src/googlenet.py
@@ -81,7 +81,7 @@ class GoogleNet(nn.Cell):
     Googlenet architecture
     """
 
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, include_top=True):
         super(GoogleNet, self).__init__()
         self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
@@ -104,11 +104,13 @@ class GoogleNet(nn.Cell):
         self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
         self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
 
-        self.mean = P.ReduceMean(keep_dims=True)
         self.dropout = nn.Dropout(keep_prob=0.8)
-        self.flatten = nn.Flatten()
-        self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
-                                   bias_init=weight_variable())
+        self.include_top = include_top
+        if self.include_top:
+            self.mean = P.ReduceMean(keep_dims=True)
+            self.flatten = nn.Flatten()
+            self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
+                                       bias_init=weight_variable())
 
 
     def construct(self, x):
@@ -133,6 +135,8 @@ class GoogleNet(nn.Cell):
 
         x = self.block5a(x)
         x = self.block5b(x)
+        if not self.include_top:
+            return x
 
         x = self.mean(x, (2, 3))
         x = self.flatten(x)