博客
关于我
强烈建议你试试无所不能的 ChatGPT，快点击我
MXNet《动手学深度学习》笔记：GoogLeNet 神经网络实现
阅读量:7020 次
发布时间:2019-06-28

本文共 3880 字，大约阅读时间需要 12 分钟。

  hot3.png

# coding: utf-8
"""GoogLeNet (Inception v1) implemented with MXNet Gluon.

Defines an ``Inception`` module and a ``GoogLeNet`` network built from six
sequential stages. When run as a script it prints an Inception output shape,
prints every stage's output shape for a random 96x96 batch, and then trains
for one epoch on Fashion-MNIST (resized to 96x96) via the local ``utils``
helpers.
"""
import os
import sys

from mxnet import gluon
from mxnet import init
from mxnet import nd
from mxnet.gluon import nn

# Make the ``utils`` module that lives in the current working directory
# importable; the import therefore has to come after this line.
sys.path.append(os.getcwd())
import utils  # noqa: E402


class Inception(nn.Block):
    """Inception module: four parallel paths concatenated along channels.

    Args:
        n1_1: output channels of path 1 (1x1 conv).
        n2_1: output channels of path 2's 1x1 reduction conv.
        n2_3: output channels of path 2's 3x3 conv.
        n3_1: output channels of path 3's 1x1 reduction conv.
        n3_5: output channels of path 3's 5x5 conv.
        n4_1: output channels of path 4's 1x1 conv (applied after 3x3 max
            pooling).

    Every path keeps the input height and width (1x1 convs, 3x3 conv with
    padding 1, 5x5 conv with padding 2, 3x3/stride-1 pooling with padding 1),
    so the output has ``n1_1 + n2_3 + n3_5 + n4_1`` channels at the input's
    spatial size.
    """

    def __init__(self, n1_1, n2_1, n2_3, n3_1, n3_5, n4_1, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # Register children under this block's prefix, as GoogLeNet does.
        with self.name_scope():
            # path 1: single 1x1 convolution
            self.p1_conv_1 = nn.Conv2D(n1_1, kernel_size=1,
                                       activation='relu')
            # path 2: 1x1 channel reduction followed by a 3x3 convolution
            self.p2_conv_1 = nn.Conv2D(n2_1, kernel_size=1,
                                       activation='relu')
            self.p2_conv_3 = nn.Conv2D(n2_3, kernel_size=3, padding=1,
                                       activation='relu')
            # path 3: 1x1 channel reduction followed by a 5x5 convolution
            self.p3_conv_1 = nn.Conv2D(n3_1, kernel_size=1,
                                       activation='relu')
            self.p3_conv_5 = nn.Conv2D(n3_5, kernel_size=5, padding=2,
                                       activation='relu')
            # path 4: 3x3 max pooling followed by a 1x1 convolution
            self.p4_pool_3 = nn.MaxPool2D(pool_size=3, padding=1, strides=1)
            self.p4_conv_1 = nn.Conv2D(n4_1, kernel_size=1,
                                       activation='relu')

    def forward(self, x):
        """Run the four paths on ``x`` and concatenate them on axis 1."""
        p1 = self.p1_conv_1(x)
        p2 = self.p2_conv_3(self.p2_conv_1(x))
        p3 = self.p3_conv_5(self.p3_conv_1(x))
        p4 = self.p4_conv_1(self.p4_pool_3(x))
        # Axis 1 is the channel axis for Conv2D's default NCHW layout.
        return nd.concat(p1, p2, p3, p4, dim=1)


class GoogLeNet(nn.Block):
    """GoogLeNet assembled from six sequential stages.

    Args:
        num_classes: number of output logits produced by the final Dense.
        verbose: if True, ``forward`` prints each stage's output shape.
    """

    def __init__(self, num_classes, verbose=False, **kwargs):
        super(GoogLeNet, self).__init__(**kwargs)
        self.verbose = verbose
        with self.name_scope():
            # block 1: 7x7 stride-2 stem convolution + max pooling
            b1 = nn.Sequential()
            b1.add(
                nn.Conv2D(64, kernel_size=7, strides=2, padding=3,
                          activation='relu'),
                nn.MaxPool2D(pool_size=3, strides=2),
            )
            # block 2: 1x1 then 3x3 convolution. Both need a ReLU; without
            # one the two convolutions collapse into a single linear map.
            b2 = nn.Sequential()
            b2.add(
                nn.Conv2D(64, kernel_size=1, activation='relu'),
                nn.Conv2D(192, kernel_size=3, padding=1, activation='relu'),
                nn.MaxPool2D(pool_size=3, strides=2),
            )
            # block 3: two Inception modules (256 and 480 output channels)
            b3 = nn.Sequential()
            b3.add(
                Inception(64, 96, 128, 16, 32, 32),
                Inception(128, 128, 192, 32, 96, 64),
                nn.MaxPool2D(pool_size=3, strides=2),
            )
            # block 4: five Inception modules (512, 512, 512, 528, 832)
            b4 = nn.Sequential()
            b4.add(
                Inception(192, 96, 208, 16, 48, 64),
                Inception(160, 112, 224, 24, 64, 64),
                Inception(128, 128, 256, 24, 64, 64),
                Inception(112, 144, 288, 32, 64, 64),
                Inception(256, 160, 320, 32, 128, 128),
                nn.MaxPool2D(pool_size=3, strides=2),
            )
            # block 5: two Inception modules (832, 1024), then global average
            # pooling. For 96x96 input this stage sees a 2x2 map, so the
            # result equals a 2x2 average pool, and other input sizes also
            # reduce to 1x1.
            b5 = nn.Sequential()
            b5.add(
                Inception(256, 160, 320, 32, 128, 128),
                Inception(384, 192, 384, 48, 128, 128),
                nn.GlobalAvgPool2D(),
            )
            # block 6: classifier head
            b6 = nn.Sequential()
            b6.add(
                nn.Flatten(),
                nn.Dense(num_classes),
            )
            # chain blocks together
            self.net = nn.Sequential()
            self.net.add(b1, b2, b3, b4, b5, b6)

    def forward(self, x):
        """Run ``x`` through every stage, optionally printing shapes."""
        out = x
        for i, b in enumerate(self.net):
            out = b(out)
            if self.verbose:
                print('Block %d output:%s' % (i + 1, out.shape))
        return out


if __name__ == '__main__':
    # Sanity check: one Inception module on a random batch.
    incp = Inception(64, 96, 128, 16, 32, 32)
    incp.initialize()
    x = nd.random.uniform(shape=(32, 3, 64, 64))
    result = incp(x)
    print(result.shape)

    # Print each stage's output shape for a random 96x96 batch.
    net = GoogLeNet(10, verbose=True)
    net.initialize()
    x = nd.random.uniform(shape=(4, 3, 96, 96))
    net(x)

    # Train one epoch on Fashion-MNIST resized to 96x96. The helpers come
    # from the local utils module; presumably try_gpu falls back to CPU —
    # confirm in utils.
    train_data, test_data = utils.load_data_fashion_mnist(batch_size=64,
                                                          resize=96)
    ctx = utils.try_gpu()
    net = GoogLeNet(10)
    net.initialize(ctx=ctx, init=init.Xavier())
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': 0.01})
    utils.train(train_data, test_data, net, loss, trainer, ctx,
                num_epochs=1)

 

转载于：https://my.oschina.net/wujux/blog/1809886

你可能感兴趣的文章
【java IO】使用Java输入输出流 读取txt文件内数据,进行拼接后写入到另一个文件中...
查看>>
【Go入门教程7】面向对象(method、指针作为receiver、method继承、method重写)
查看>>
iOS开发小技巧 -- tableView-section圆角边框解决方案
查看>>
URL动态赋值
查看>>
adb shell dumpsys的使用
查看>>
设计模式之策略模式
查看>>
[LintCode] Longest Common Prefix 最长共同前缀
查看>>
redis-cli中那些或许我们还不知道的一些实用小功能
查看>>
下载历史版本App超详细教程
查看>>
数据标准化(转)
查看>>
常见的开源日志(包括分布式)
查看>>
MongoDb gridfs-ngnix文件存储方案
查看>>
ArcEngine数据编辑--选择要素
查看>>
微信小程序正式发布!这是最全的上手指南
查看>>
Linux下ls命令使用详解(转)
查看>>
第十四篇:Apriori 关联分析算法原理分析与代码实现
查看>>
putty简单使用
查看>>
B A T 开源项目(转载)
查看>>
Java主流Web Service框架介绍:CXF和Axis2
查看>>
C++对象模型-构造函数语意学
查看>>