{"id":403,"date":"2022-08-07T15:19:11","date_gmt":"2022-08-07T07:19:11","guid":{"rendered":"http:\/\/www.gislxz.top\/?p=403"},"modified":"2022-08-14T16:16:43","modified_gmt":"2022-08-14T08:16:43","slug":"%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%8810%ef%bc%89","status":"publish","type":"post","link":"https:\/\/www.gislxz.com\/index.php\/2022\/08\/07\/%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%8810%ef%bc%89\/","title":{"rendered":"\u6df1\u5ea6\u5b66\u4e60\u7b14\u8bb0\uff0810\uff09"},"content":{"rendered":"\n<p>paddle\u5b98\u65b9\u4e0a\u4e00\u7ae0\u6559\u7a0b\u4ecb\u7ecd\u4e86\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u7684\u4e00\u4e9b\u57fa\u672c\u6982\u5ff5\u548c\u6570\u5b66\u539f\u7406\uff0c\u8fd9\u4e00\u7ae0\u5c31\u4ee5\u56fe\u50cf\u5206\u7c7b\u4efb\u52a1\u5165\u624b\u4ecb\u7ecd\u5404\u79cd\u7ecf\u5178\u5377\u79ef\u7f51\u7edc\u7ed3\u6784<\/p>\n\n\n\n<p>\u5b98\u65b9\u6559\u7a0b\u94fe\u63a5\uff1a<a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/1613144\" target=\"_blank\"  rel=\"nofollow\" >\u56fe\u50cf\u5206\u7c7b<\/a><\/p>\n\n\n\n<p>\u9996\u5148\u662fLeNet\uff0c\u8fd9\u4e2a\u7f51\u7edc\u7ed3\u6784\u76f8\u5bf9\u6bd4\u8f83\u7b80\u5355\uff0c\u548c\u4e4b\u524d\u624b\u5199\u6570\u5b57\u7684\u7f51\u7edc\u7ed3\u6784\u76f8\u4f3c\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u5bfc\u5165\u9700\u8981\u7684\u5305\nimport paddle\nimport numpy as np\nfrom paddle.nn import Conv2D, MaxPool2D, Linear\n\n## \u7ec4\u7f51\nimport paddle.nn.functional as F\n\n# \u5b9a\u4e49 LeNet \u7f51\u7edc\u7ed3\u6784\nclass LeNet(paddle.nn.Layer):\n    def __init__(self, num_classes=1):\n        super(LeNet, self).__init__()\n        # \u521b\u5efa\u5377\u79ef\u548c\u6c60\u5316\u5c42\u5757\uff0c\n        # \u521b\u5efa\u7b2c1\u4e2a\u5377\u79ef\u5c42\n        self.conv1 = Conv2D(in_channels=1, out_channels=6, kernel_size=5)\n        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)\n        # # 
\u521b\u5efa\u7b2c2\u4e2a\u5377\u79ef\u5c42\uff1b\u5c3a\u5bf8\u7684\u903b\u8f91\uff1a\u6c60\u5316\u5c42\u672a\u6539\u53d8\u901a\u9053\u6570\uff1b\u8f93\u51fa\u901a\u9053\u7b49\u4e8e\u5377\u79ef\u6838\u7684\u6570\u91cf\uff0c\u8be5\u5c42\u51716*16\u4e2a\u5377\u79ef\u6838\n        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5)\n        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)\n        # \u521b\u5efa\u7b2c3\u4e2a\u5377\u79ef\u5c42\n        self.conv3 = Conv2D(in_channels=16, out_channels=120, kernel_size=4)\n        # \u5c3a\u5bf8\u7684\u903b\u8f91\uff1a\u8f93\u5165\u5c42\u5c06\u6570\u636e\u62c9\u5e73&#91;B,C,H,W] -&gt; &#91;B,C*H*W]\n        # \u8f93\u5165size\u662f&#91;32,32]\uff0c\u7ecf\u8fc7\u4e09\u6b21\u5377\u79ef\u548c\u4e24\u6b21\u6c60\u5316\u4e4b\u540e\uff0cC*H*W\u7b49\u4e8e120\n        self.fc1 = Linear(in_features=120, out_features=64)\n        # \u521b\u5efa\u5168\u8fde\u63a5\u5c42\uff0c\u7b2c\u4e00\u4e2a\u5168\u8fde\u63a5\u5c42\u7684\u8f93\u51fa\u795e\u7ecf\u5143\u4e2a\u6570\u4e3a64\uff0c \u7b2c\u4e8c\u4e2a\u5168\u8fde\u63a5\u5c42\u8f93\u51fa\u795e\u7ecf\u5143\u4e2a\u6570\u4e3a\u5206\u7c7b\u6807\u7b7e\u7684\u7c7b\u522b\u6570\n        self.fc2 = Linear(in_features=64, out_features=num_classes)\n    # \u7f51\u7edc\u7684\u524d\u5411\u8ba1\u7b97\u8fc7\u7a0b\n    def forward(self, x):\n        x = self.conv1(x)\n        # \u6bcf\u4e2a\u5377\u79ef\u5c42\u4f7f\u7528Sigmoid\u6fc0\u6d3b\u51fd\u6570\uff0c\u540e\u9762\u8ddf\u7740\u4e00\u4e2a2x2\u7684\u6c60\u5316\n        x = F.sigmoid(x)\n        x = self.max_pool1(x)\n        x = F.sigmoid(x)\n        x = self.conv2(x)\n        x = self.max_pool2(x)\n        x = self.conv3(x)\n        # \u5c3a\u5bf8\u7684\u903b\u8f91\uff1a\u8f93\u5165\u5c42\u5c06\u6570\u636e\u62c9\u5e73&#91;B,C,H,W] -&gt; &#91;B,C*H*W]\n        x = paddle.reshape(x, &#91;x.shape&#91;0], -1])\n        x = self.fc1(x)\n        x = F.sigmoid(x)\n        x = self.fc2(x)\n        return x<\/code><\/pre>\n\n\n\n<figure 
class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/ai-studio-static-online.cdn.bcebos.com\/33bbff96924e4b36b613f0c1c36a89dfb72e3b56b3be464dbbce22f7ce575b0d\" alt=\"\"\/><\/figure>\n\n\n\n<p>\u5728\u5b66\u4e60\u522b\u4eba\u63d0\u51fa\u7684\u7ecf\u5178\u7f51\u7edc\u7ed3\u6784\u65f6\uff0c\u6211\u4eec\u5e94\u8be5\u81ea\u5df1\u60f3\u60f3\u6bcf\u4e00\u5c42\u8f93\u51fa\u7684\u77e9\u9635shape\u662f\u4ec0\u4e48\uff0c\u5e76\u4e14\u7528\u968f\u673a\u6570\u4f5c\u4e3a\u8f93\u5165\uff0c\u67e5\u770b\u7ecf\u8fc7LeNet-5\u7684\u6bcf\u4e00\u5c42\u4f5c\u7528\u4e4b\u540e\uff0c\u8f93\u51fa\u6570\u636e\u7684\u5f62\u72b6\u6765\u9a8c\u8bc1\u81ea\u5df1\u7684\u8ba1\u7b97\u662f\u5426\u6b63\u786e\u3002<\/p>\n\n\n\n<p>\u5982\u679c\u8f93\u5165\u662f28\u00d728\u7684\u56fe\u50cf\uff0c\u7b2c\u4e00\u5c42\u5377\u79ef\u662f5\u00d75\u7684\u5377\u79ef\u6838\uff0c\u51716\u4e2a\u901a\u9053\uff0c\u6b65\u957f\u9ed8\u8ba4\u662f1\uff0c\u90a3\u4e48\u5377\u79ef\u540e\u5f97\u5230\u7684\u7ed3\u679c\u5c31\u662f24\u00d724\u00d76\u901a\u9053<\/p>\n\n\n\n<p>\u4e4b\u540e2\u00d72\u6c60\u5316\uff0c\u6b65\u957f2\uff0c\u8f93\u51fa\u5c3a\u5bf8\u51cf\u534a\uff0c\u901a\u9053\u4e0d\u53d8\uff0c\u7ed3\u679c\u662f12\u00d712\u00d76<\/p>\n\n\n\n<p>\u7b2c\u4e09\u5c42\u540c\u6837\u7684\u5377\u79ef\uff0c\u540c\u68375\u00d75\uff0c\u8f93\u5165\u901a\u90536\uff0c\u8f93\u51fa\u901a\u905316\uff08\u8f93\u5165\u901a\u9053\u8981\u6ce8\u610f\u548c\u4e0a\u4e00\u5c42\u7684\u8f93\u51fa\u7ed3\u679c\u901a\u9053\u6570\u4e00\u81f4\uff09\uff0c\u6b65\u957f\u8fd8\u662f1\uff0c\u8f93\u51fa\u7ed3\u679c\u5c31\u662f8\u00d78\u00d716<\/p>\n\n\n\n<p>\u7b2c\u56db\u5c42\u4e00\u6837\u7684\u6c60\u5316\uff0c\u5c3a\u5bf8\u51cf\u534a\uff0c\u8f93\u51fa4\u00d74\u00d716<\/p>\n\n\n\n<p>\u7b2c\u4e94\u5c42\uff0c\u6700\u540e\u4e00\u6b21\u5377\u79ef\uff0c\u8f6c\u4e3a120\u901a\u9053\uff0c\u8f93\u51fa1\u00d71\u00d7120\u3002\u6ce8\u610f\u8fd9\u91cc\u56e0\u4e3a\u8f93\u5165\u662f28\u00d728\uff0c\u6240\u4ee5\u6700\u540e\u4e00\u5c42\u5377\u79ef\u7684kernel\u662f4\u00d7
4\u7684\u5c3a\u5bf8\u3002\u8981\u6c42\u6700\u540e\u5377\u79ef\u5b8c\u662f1\u00d71\u7684size\u597d\u540e\u9762\u8fdb\u884c\u7ebf\u6027\u8ba1\u7b97\u3002\u5982\u679c\u8f93\u5165\u56fe\u50cf\u662f32\u00d732\uff0c\u4e5f\u5c31\u662f\u548c\u4e0a\u9762\u793a\u610f\u56fe\u4e00\u6837\uff0c\u90a3\u4e48\u6bcf\u5c42\u7684size\u5c31\u53d8\u6210\u4e86\u301028\u00d728\u00d76-&gt;14\u00d714\u00d76-&gt;10\u00d710\u00d716-&gt;5\u00d75\u00d716\u3011\u90a3\u4e48\u6700\u540e\u4e00\u5c42\u5377\u79ef\u5c31\u8981\u628a\u5377\u79ef\u6838\uff08kernel\uff09\u7684size\u6539\u62105\u00d75\u3002<\/p>\n\n\n\n<p>\u4e4b\u540e\u4e24\u5c42\u7ebf\u6027\u5316\u5c31\u662f\u628a\u5148\u628a1\u00d71\u00d7120\u7684\u8f93\u51fa\u964d\u7ef4\u6210120\u7684\u6570\u7ec4\uff0c\u7136\u540e\u7ebf\u6027\u8ba1\u7b97\u7b2c\u4e00\u6b21\u523064\uff0c\u7b2c\u4e8c\u6b21\u523010\uff08num_class\uff0c\u7c7b\u522b\u6570\uff09<\/p>\n\n\n\n<p>\u5f53\u7136\u4e0a\u9762\u90fd\u6ca1\u7b97batch-size\uff0cbatch-size\u76f4\u63a5\u653e\u524d\u9762\u5c31\u884c\uff0c\u6ca1\u6709\u5f71\u54cd\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u8f93\u5165\u6570\u636e\u5f62\u72b6\u662f &#91;N, 1, H, W]\n# \u8fd9\u91cc\u7528np.random\u521b\u5efa\u4e00\u4e2a\u968f\u673a\u6570\u7ec4\u4f5c\u4e3a\u8f93\u5165\u6570\u636e\nx = np.random.randn(*&#91;3,1,28,28])\nx = x.astype('float32')\n\n# \u521b\u5efaLeNet\u7c7b\u7684\u5b9e\u4f8b\uff0c\u6307\u5b9a\u6a21\u578b\u540d\u79f0\u548c\u5206\u7c7b\u7684\u7c7b\u522b\u6570\u76ee\nm = LeNet(num_classes=10)\n# \u901a\u8fc7\u8c03\u7528LeNet\u4ece\u57fa\u7c7b\u7ee7\u627f\u7684sublayers()\u51fd\u6570\uff0c\n# \u67e5\u770bLeNet\u4e2d\u6240\u5305\u542b\u7684\u5b50\u5c42\nprint(m.sublayers())\nx = paddle.to_tensor(x)\nfor item in m.sublayers():\n    # item\u662fLeNet\u7c7b\u4e2d\u7684\u4e00\u4e2a\u5b50\u5c42\n    # \u67e5\u770b\u7ecf\u8fc7\u5b50\u5c42\u4e4b\u540e\u7684\u8f93\u51fa\u6570\u636e\u5f62\u72b6\n    try:\n        x = item(x)\n    except:\n        x = paddle.reshape(x, &#91;x.shape&#91;0], -1])\n     
   x = item(x)\n    if len(item.parameters())==2:\n        # \u67e5\u770b\u5377\u79ef\u548c\u5168\u8fde\u63a5\u5c42\u7684\u6570\u636e\u548c\u53c2\u6570\u7684\u5f62\u72b6\uff0c\n        # \u5176\u4e2ditem.parameters()&#91;0]\u662f\u6743\u91cd\u53c2\u6570w\uff0citem.parameters()&#91;1]\u662f\u504f\u7f6e\u53c2\u6570b\n        print(item.full_name(), x.shape, item.parameters()&#91;0].shape, item.parameters()&#91;1].shape)\n    else:\n        # \u6c60\u5316\u5c42\u6ca1\u6709\u53c2\u6570\n        print(item.full_name(), x.shape)<\/code><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">[Conv2D(1, 6, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(6, 16, kernel_size=[5, 5], data_format=NCHW), MaxPool2D(kernel_size=2, stride=2, padding=0), Conv2D(16, 120, kernel_size=[4, 4], data_format=NCHW), Linear(in_features=120, out_features=64, dtype=float32), Linear(in_features=64, out_features=10, dtype=float32)]\nconv2d_6 [3, 6, 24, 24] [6, 1, 5, 5] [6]\nmax_pool2d_4 [3, 6, 12, 12]\nconv2d_7 [3, 16, 8, 8] [16, 6, 5, 5] [16]\nmax_pool2d_5 [3, 16, 4, 4]\nconv2d_8 [3, 120, 1, 1] [120, 16, 4, 4] [120]\nlinear_4 [3, 64] [120, 64] [64]\nlinear_5 [3, 10] [64, 10] [10]<\/pre>\n\n\n\n<p>\u4e0e\u6211\u4eec\u7684\u8ba1\u7b97\u4e00\u81f4<\/p>\n\n\n\n<p>\u8fd9\u91cc\u5377\u79ef\u6838\u7684\u53c2\u6570\u4e5f\u8981\u7740\u91cd\u7406\u89e3\u4e00\u4e0b\uff0c\u7b2c\u4e00\u5c42\u5c31\u662f\u516d\u4e2a5\u00d75\u7684\u5377\u79ef\u6838\uff0c\u5f88\u7b80\u5355\u3002<\/p>\n\n\n\n<p>\u7b2c\u4e8c\u5c42\u5377\u79ef\u662f6\u901a\u9053-&gt;16\u901a\u9053\uff0c\u8f93\u51fa\u7684\u6bcf\u4e00\u4e2a\u901a\u9053\u90fd\u662f\u4e0a\u4e00\u6b65\u7ed3\u679c6\u4e2a\u901a\u9053\u5377\u79ef\u51fa\u6765\u5728\u76f8\u52a0\u7684\u7ed3\u679c\uff0c\u4e5f\u5c31\u662f\u4e00\u51716\u00d716\u4e2a\u5377\u79ef\u6838\uff0c\u53ef\u4ee5\u7ed3\u5408\u4e0b\u56fe\u7406\u89e3<\/p>\n\n\n\n<p>\u591a\u901a\u9053\u8f93\u5165<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" 
src=\"https:\/\/ai-studio-static-online.cdn.bcebos.com\/92186667b8424a7ca781b22de6766fa62e31512cf2e24e33a4b796541177c9dd\" alt=\"\"\/><\/figure>\n\n\n\n<p>\u591a\u901a\u9053\u8f93\u5165\uff0b\u591a\u901a\u9053\u8f93\u51fa<\/p>\n\n\n\n<figure class=\"wp-block-image\"><img decoding=\"async\" src=\"https:\/\/ai-studio-static-online.cdn.bcebos.com\/cf1fbddc141349e4b7aaeade9a201b78a16d249e069c4f8aaeb77e0ea1a95c31\" alt=\"\"\/><\/figure>\n\n\n\n<p>\u4e0b\u9762\u8fdb\u884c\u8bad\u7ec3\uff0c\u548c\u4e4b\u524d\u7684\u624b\u5199\u6570\u5b57\u8bc6\u522b\u7684\u4ee3\u7801\u6ca1\u6709\u533a\u522b\uff0c\u53ea\u662fmodel\u7684\u7ed3\u6784\u6539\u6210\u4e86LeNet<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># -*- coding: utf-8 -*-\n# LeNet \u8bc6\u522b\u624b\u5199\u6570\u5b57\nimport os\nimport random\nimport paddle\nimport numpy as np\n\n# \u5b9a\u4e49\u8bad\u7ec3\u8fc7\u7a0b\ndef train(model):\n\n    # \u5f00\u542f0\u53f7GPU\u8bad\u7ec3\n    use_gpu = True\n    paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')\n    print('start training ... 
')\n    model.train()\n    epoch_num = 5\n    opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())\n\n    # \u4f7f\u7528Paddle\u81ea\u5e26\u7684\u6570\u636e\u8bfb\u53d6\u5668\n    train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=10)\n    valid_loader = paddle.batch(paddle.dataset.mnist.test(), batch_size=10)\n    for epoch in range(epoch_num):\n        for batch_id, data in enumerate(train_loader()):\n            # \u8c03\u6574\u8f93\u5165\u6570\u636e\u5f62\u72b6\u548c\u7c7b\u578b\n            x_data = np.array(&#91;item&#91;0] for item in data], dtype='float32').reshape(-1, 1, 28, 28)\n            y_data = np.array(&#91;item&#91;1] for item in data], dtype='int64').reshape(-1, 1)\n            # \u5c06numpy.ndarray\u8f6c\u5316\u6210Tensor\n            img = paddle.to_tensor(x_data)\n            label = paddle.to_tensor(y_data)\n            # \u8ba1\u7b97\u6a21\u578b\u8f93\u51fa\n            logits = model(img)\n            # \u8ba1\u7b97\u635f\u5931\u51fd\u6570\n            loss = F.softmax_with_cross_entropy(logits, label)\n            avg_loss = paddle.mean(loss)\n\n            if batch_id % 1000 == 0:\n                print(\"epoch: {}, batch_id: {}, loss is: {}\".format(epoch, batch_id, avg_loss.numpy()))\n            avg_loss.backward()\n            opt.step()\n            opt.clear_grad()\n\n        model.eval()\n        accuracies = &#91;]\n        losses = &#91;]\n        for batch_id, data in enumerate(valid_loader()):\n            # \u8c03\u6574\u8f93\u5165\u6570\u636e\u5f62\u72b6\u548c\u7c7b\u578b\n            x_data = np.array(&#91;item&#91;0] for item in data], dtype='float32').reshape(-1, 1, 28, 28)\n            y_data = np.array(&#91;item&#91;1] for item in data], dtype='int64').reshape(-1, 1)\n            # \u5c06numpy.ndarray\u8f6c\u5316\u6210Tensor\n            img = paddle.to_tensor(x_data)\n            label = paddle.to_tensor(y_data)\n            # 
\u8ba1\u7b97\u6a21\u578b\u8f93\u51fa\n            logits = model(img)\n            pred = F.softmax(logits)\n            # \u8ba1\u7b97\u635f\u5931\u51fd\u6570\n            loss = F.softmax_with_cross_entropy(logits, label)\n            acc = paddle.metric.accuracy(pred, label)\n            accuracies.append(acc.numpy())\n            losses.append(loss.numpy())\n        print(\"&#91;validation] accuracy\/loss: {}\/{}\".format(np.mean(accuracies), np.mean(losses)))\n        model.train()\n\n    # \u4fdd\u5b58\u6a21\u578b\u53c2\u6570\n    paddle.save(model.state_dict(), 'mnist.pdparams')\n# \u521b\u5efa\u6a21\u578b\nmodel = LeNet(num_classes=10)\n# \u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\ntrain(model)<\/code><\/pre>\n\n\n\n<p>\u63a5\u4e0b\u6765\u6211\u4eec\u7528torch\u6765\u5b9e\u73b0\u4e00\u904d<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch import nn\nfrom torch import optim\nimport torch.nn.functional as F\nimport torchvision\nimport json\nimport numpy as np\nimport gzip\nimport random\n\n#paddle\u8fd9\u91cc\u76f4\u63a5\u7528\u7684\u81ea\u5e26\u7684\u6570\u636e\u8bfb\u53d6\u5668\u3002\u6211\u4eec\u8fd8\u662f\u7528\u81ea\u5df1\u5199\u7684\u6570\u636e\u8bfb\u53d6\u51fd\u6570\ndef load_data(mode='train'):\n    datafile = r'E:NLPDATA\/MNIST\/mnist.json.gz'\n    print('loading mnist dataset from {} ......'.format(datafile))\n    # \u52a0\u8f7djson\u6570\u636e\u6587\u4ef6\n    data = json.load(gzip.open(datafile))\n    print('mnist dataset load done')\n\n    # \u8bfb\u53d6\u5230\u7684\u6570\u636e\u533a\u5206\u8bad\u7ec3\u96c6\uff0c\u9a8c\u8bc1\u96c6\uff0c\u6d4b\u8bd5\u96c6\n    train_set, val_set, eval_set = data\n    if mode == 'train':\n        # \u83b7\u5f97\u8bad\u7ec3\u6570\u636e\u96c6\n        imgs, labels = train_set&#91;0], train_set&#91;1]\n    elif mode == 'valid':\n        # \u83b7\u5f97\u9a8c\u8bc1\u6570\u636e\u96c6\n        imgs, labels = val_set&#91;0], val_set&#91;1]\n    elif mode == 'eval':\n        # 
\u83b7\u5f97\u6d4b\u8bd5\u6570\u636e\u96c6\n        imgs, labels = eval_set&#91;0], eval_set&#91;1]\n    else:\n        raise Exception(\"mode can only be one of &#91;'train', 'valid', 'eval']\")\n    print(\"\u8bad\u7ec3\u6570\u636e\u96c6\u6570\u91cf: \", len(imgs))\n\n    # \u6821\u9a8c\u6570\u636e\n    imgs_length = len(imgs)\n\n    assert len(imgs) == len(labels), \\\n          \"length of train_imgs({}) should be the same as train_labels({})\".format(len(imgs), len(labels))\n\n    # \u83b7\u5f97\u6570\u636e\u96c6\u957f\u5ea6\n    imgs_length = len(imgs)\n    # \u5b9a\u4e49\u6570\u636e\u96c6\u6bcf\u4e2a\u6570\u636e\u7684\u5e8f\u53f7\uff0c\u6839\u636e\u5e8f\u53f7\u8bfb\u53d6\u6570\u636e\n    index_list = list(range(imgs_length))\n    # \u8bfb\u5165\u6570\u636e\u65f6\u7528\u5230\u7684\u6279\u6b21\u5927\u5c0f\n    BATCHSIZE = 100\n\n    # \u5b9a\u4e49\u6570\u636e\u751f\u6210\u5668\n    def data_generator():\n        if mode == 'train':\n            # \u8bad\u7ec3\u6a21\u5f0f\u4e0b\u6253\u4e71\u6570\u636e\n            random.shuffle(index_list)\n        imgs_list = &#91;]\n        labels_list = &#91;]\n        for i in index_list:\n            # \u5c06\u6570\u636e\u5904\u7406\u6210\u5e0c\u671b\u7684\u683c\u5f0f\uff0c\u6bd4\u5982\u7c7b\u578b\u4e3afloat32\uff0cshape\u4e3a&#91;1, 28, 28]\n            img = np.reshape(imgs&#91;i], &#91;1, 28, 28]).astype('float32')\n            label = np.reshape(labels&#91;i], &#91;1]).astype('int64')\n            imgs_list.append(img)\n            labels_list.append(label)\n            if len(imgs_list) == BATCHSIZE:\n                # \u83b7\u5f97\u4e00\u4e2abatchsize\u7684\u6570\u636e\uff0c\u5e76\u8fd4\u56de\n                yield np.array(imgs_list), np.array(labels_list)\n                # \u6e05\u7a7a\u6570\u636e\u8bfb\u53d6\u5217\u8868\n                imgs_list = &#91;]\n                labels_list = &#91;]\n\n        # \u5982\u679c\u5269\u4f59\u6570\u636e\u7684\u6570\u76ee\u5c0f\u4e8eBATCHSIZE\uff0c\n        # 
\u5219\u5269\u4f59\u6570\u636e\u4e00\u8d77\u6784\u6210\u4e00\u4e2a\u5927\u5c0f\u4e3alen(imgs_list)\u7684mini-batch\n        if len(imgs_list) &gt; 0:\n            yield np.array(imgs_list), np.array(labels_list)\n\n    return data_generator\n\n#\u5b9a\u4e49\u6a21\u578b\u7ed3\u6784\nfrom torch.nn import Conv2d, MaxPool2d, Linear\nclass Mnist(nn.Module):\n    def __init__(self):\n        super(Mnist, self).__init__()\n        self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5)\n        self.max_pool1 = MaxPool2d(kernel_size=2, stride=2)\n        self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5)\n        self.max_pool2 = MaxPool2d(kernel_size=2, stride=2)\n        self.conv3 = Conv2d(in_channels=16, out_channels=120, kernel_size=4)\n        self.fc1 = Linear(in_features=120, out_features=64)\n        self.fc2 = Linear(in_features=64, out_features=10)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = torch.sigmoid(x)\n        x = self.max_pool1(x)\n        x = self.conv2(x)\n        x = torch.sigmoid(x)\n        x = self.max_pool2(x)\n        x = self.conv3(x)\n        x = torch.sigmoid(x)\n        x = torch.reshape(x, &#91;x.shape&#91;0], -1])\n        x = self.fc1(x)\n        x = torch.sigmoid(x)\n        x = self.fc2(x)\n        return x\n\n# \u8bad\u7ec3\u914d\u7f6e\uff0c\u5e76\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\nmodel = Mnist()\nmodel = model.cuda()\nmodel.train(mode=True)\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\ntrain_loader = load_data('train')\noptimizer = optim.Adam(model.parameters(), lr=0.001)\n\nBATCHSIZE = 100\nEPOCH_NUM = 10\nfor epoch_id in range(EPOCH_NUM):\n    correct = 0\n    for batch_id, data in enumerate(train_loader()):\n        #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\n        image_data, label_data = data\n        image = torch.tensor(image_data).cuda()\n        label = torch.tensor(label_data).cuda()\n        #image = torch.reshape(image, 
&#91;image.shape&#91;0], 1, 28, 28])\n        #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\n        if batch_id == 0 and epoch_id == 0:\n            predict = model(image)\n        elif batch_id == 401:\n            predict = model(image)\n        else:\n            predict = model(image)\n        predict_label = torch.max(predict, 1)&#91;1]\n        correct += (predict_label == label.squeeze(dim=1)).sum()\n        #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\n        loss = F.cross_entropy(predict, label.squeeze(dim=1)).cuda()\n        avg_loss = torch.mean(loss)\n        #\u6bcf\u8bad\u7ec3\u4e86150\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\n        if batch_id != 0 and batch_id % 150 == 0:\n            print(\"epoch: {}, batch: {}, loss is: {}, Accuracy:{:.3f}\".format(epoch_id, batch_id, avg_loss.cpu().detach().numpy(),\\\n                                                                              correct\/(BATCHSIZE*150)))\n            correct = 0\n        #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\n        avg_loss.backward()\n        optimizer.step()\n        model.zero_grad()\ntorch.save(model.state_dict(), 'mnist_test.pdparams')\/\/torch\u7684\u4fdd\u5b58state_dict\u4e00\u5b9a\u4e00\u5b9a\u8981\u52a0\u62ec\u53f7\uff0ctorch\u8981\u4e0d\u8981\u52a0\u62ec\u53f7\u771f\u7684\u8ff7<\/code><\/pre>\n\n\n\n<p>\u8fd9\u91cc\u6ca1\u6709\u8f93\u51fa\u6bcf\u4e2aepoch\u5f00\u59cb\uff0cbatch_id=0\u65f6\u7684accuracy\uff0c\u4e3b\u8981\u662f\u56fe\u7701\u4e8b\uff0c\u6bcf\u4e2aepoch\u6700\u540e\u4e00\u7ec4\u4e0d\u8db3150batch\u7684\u4e0d\u592a\u65b9\u4fbf\u7b97\uff0c\u5e94\u8be5\u7528\u4e2alist\u628a\u6bcf\u4e2abatch\u7684accuracy\u88c5\u8d77\u6765\u6c42\u4e2amean\uff0c\u61d2\u5f97\u6539\u4e86\u3002<\/p>\n\n\n\n<div class=\"wp-block-image\"><figure class=\"aligncenter size-large\"><img loading=\"lazy\" decoding=\"async\" 
width=\"1024\" height=\"576\" src=\"http:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-1024x576.png\" alt=\"\" class=\"wp-image-404\" srcset=\"https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-1024x576.png 1024w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-300x169.png 300w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-768x432.png 768w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-1536x864.png 1536w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/08\/IMG_20220807_025451-2048x1152.png 2048w\" sizes=\"auto, (max-width: 1024px) 100vw, 1024px\" \/><figcaption>Lycoris\u66f4\u65b0\u5566<\/figcaption><\/figure><\/div>\n","protected":false},"excerpt":{"rendered":"<p>paddle\u548cpytorch\u5206\u522b\u5b9e\u73b0LeNet\u5bf9\u624b\u5199\u6570\u5b57\u8fdb\u884c\u8bc6\u522b<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[21],"tags":[],"class_list":["post-403","post","type-post","status-publish","format-standard","hentry","category-21"],"jetpack_featured_media_url":"","jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/403","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/comments?post=403"}],"version-history":[{"count":3,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/403\/revisions"}],"predecessor-version":[
{"id":415,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/403\/revisions\/415"}],"wp:attachment":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/media?parent=403"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/categories?post=403"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/tags?post=403"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}