{"id":392,"date":"2022-08-04T13:08:27","date_gmt":"2022-08-04T05:08:27","guid":{"rendered":"http:\/\/www.gislxz.top\/?p=392"},"modified":"2022-08-04T13:08:27","modified_gmt":"2022-08-04T05:08:27","slug":"%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%887%ef%bc%89","status":"publish","type":"post","link":"https:\/\/www.gislxz.com\/index.php\/2022\/08\/04\/%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%887%ef%bc%89\/","title":{"rendered":"\u6df1\u5ea6\u5b66\u4e60\u7b14\u8bb0\uff087\uff09"},"content":{"rendered":"\n<h2 class=\"wp-block-heading\">\u7ee7\u7eed\u8ddf\u7740paddle\u7684\u5b98\u65b9\u6559\u7a0b\uff0c\u8fd9\u4e00\u7ae0\u8bb2\u7684\u662f\u4f18\u5316\u7b97\u6cd5--<a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/1599221\" target=\"_blank\"  rel=\"nofollow\" >\u624b\u5199\u6570\u5b57\u8bc6\u522b\u4e4b\u4f18\u5316\u7b97\u6cd5<\/a><\/h2>\n\n\n\n<p>\u8fd9\u4e00\u7ae0\u6ca1\u6709\u592a\u591a\u7684\u4ee3\u7801\u90e8\u5206\uff0c\u4e3b\u8981\u662f\u539f\u7406\u7684\u8bb2\u89e3\uff0c\u8df3\u8fc7<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">paddle\u5b98\u65b9\u6559\u7a0b\u4e0b\u4e00\u7ae0\u8bb2\u7684\u662f<a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/1599989\" target=\"_blank\"  rel=\"nofollow\" >\u8d44\u6e90\u914d\u7f6e<\/a>\uff0c\u4e5f\u5c31\u662f\u4f7f\u7528GPU\u8bad\u7ec3\u548c\u591a\u5361\u8bad\u7ec3\u3002<\/h2>\n\n\n\n<p>\u4f7f\u7528GPU\u5355\u5361\u8bad\u7ec3\u5f88\u7b80\u5355\uff0c\u52a0\u4e00\u53e5use_gpu = True\u5c31\u5b8c\u4e8b\u4e86<\/p>\n\n\n\n<p>torch\u4e2d\u4f7f\u7528gpu\u7a0d\u5fae\u590d\u6742\u4e00\u70b9\uff0cmodel\uff0cdata\u548closs\u90fd\u8981\u5728\u540e\u9762\u52a0cuda()\u653e\u5165\u663e\u5b58<\/p>\n\n\n\n<p>\u6253\u5370\u7684\u65f6\u5019\u8fd8\u8981\u52a0cpu()\u56de\u5230\u5185\u5b58<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch.nn import 
Conv2d,MaxPool2d,Linear\n#\u6570\u636e\u5904\u7406\u90e8\u5206\u4e4b\u540e\u7684\u4ee3\u7801\uff0c\u6570\u636e\u8bfb\u53d6\u7684\u90e8\u5206\u8c03\u7528load_data\u51fd\u6570\n# \u5b9a\u4e49\u7f51\u7edc\u7ed3\u6784\uff0c\u540c\u4e0a\u4e00\u8282\u6240\u4f7f\u7528\u7684\u7f51\u7edc\u7ed3\u6784\nclass Mnist(nn.Module):\n    def __init__(self):\n        super(Mnist,self).__init__()\n        self.conv1 = Conv2d(in_channels=1,out_channels=20,kernel_size=5,stride=1,padding=2)\n        self.max_pool1 = MaxPool2d(kernel_size=2,stride=2)\n        self.conv2 = Conv2d(in_channels=20,out_channels=20,kernel_size=5,stride=1,padding=2)\n        self.max_pool2 = MaxPool2d(kernel_size=2,stride=2)\n        self.fc = Linear(in_features=980,out_features=10)\n    def forward(self,x):\n        x = self.conv1(x)\n        x = torch.relu(x)\n        x = self.max_pool1(x)\n        x = self.conv2(x)\n        x = torch.relu(x)\n        x = self.max_pool2(x)\n        x = torch.reshape(x,&#91;x.shape&#91;0],-1])\n        x = self.fc(x)\n        x = F.softmax(x,dim=1)\n        return x\n\n# \u8bad\u7ec3\u914d\u7f6e\uff0c\u5e76\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\nmodel = Mnist()\nmodel=model.cuda()\nmodel.train(mode=True)\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\ntrain_loader = load_data('train')\noptimizer = optim.SGD(model.parameters(),lr= 0.01)\n\nEPOCH_NUM = 10\nfor epoch_id in range(EPOCH_NUM):\n    for batch_id, data in enumerate(train_loader()):\n        #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\n        image_data, label_data = data\n        image = torch.tensor(image_data).cuda()\n        label = torch.tensor(label_data).cuda()\n        image = torch.reshape(image,&#91;image.shape&#91;0],1,28,28]) \n        #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\n        predict = model(image)\n            \n        #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\n        loss = 
F.cross_entropy(predict, label.squeeze(dim=1)).cuda()\n        avg_loss = torch.mean(loss)\n            \n        #\u6bcf\u8bad\u7ec3\u4e86200\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\n        if batch_id % 200 == 0:\n            print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.cpu().detach().numpy()))\n            \n        #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\n        avg_loss.backward()\n        optimizer.step()\n        model.zero_grad()<\/code><\/pre>\n\n\n\n<p>\u591a\u5361\u8bad\u7ec3\u6ca1\u8fd9\u4e2a\u6761\u4ef6\uff0c\u770b\u770b\u5230\u4e86\u6240\u91cc\u4e0d\u77e5\u9053\u670d\u52a1\u5668\u6709\u6ca1\u6709\u591a\u5361\uff0c\u5230\u65f6\u5019\u7814\u7a76\u7814\u7a76<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">\u8fd9\u7ae0\u5185\u5bb9\u4e5f\u5f88\u5c11\uff0c\u76f4\u63a5\u4e0b\u4e00\u7ae0<a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/1606145\" target=\"_blank\"  rel=\"nofollow\" >\u3010\u624b\u5199\u6570\u5b57\u8bc6\u522b\u3011\u4e4b\u8bad\u7ec3\u8c03\u8bd5\u4e0e\u4f18\u5316<\/a><\/h2>\n\n\n\n<p>\u8bad\u7ec3\u8fc7\u7a0b\u4f18\u5316\u601d\u8def\u4e3b\u8981\u6709\u5982\u4e0b\u4e94\u4e2a\u5173\u952e\u73af\u8282\uff1a<\/p>\n\n\n\n<p><strong>1. \u8ba1\u7b97\u5206\u7c7b\u51c6\u786e\u7387\uff0c\u89c2\u6d4b\u6a21\u578b\u8bad\u7ec3\u6548\u679c\u3002<\/strong><\/p>\n\n\n\n<p>\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u53ea\u80fd\u4f5c\u4e3a\u4f18\u5316\u76ee\u6807\uff0c\u65e0\u6cd5\u76f4\u63a5\u51c6\u786e\u8861\u91cf\u6a21\u578b\u7684\u8bad\u7ec3\u6548\u679c\u3002\u51c6\u786e\u7387\u53ef\u4ee5\u76f4\u63a5\u8861\u91cf\u8bad\u7ec3\u6548\u679c\uff0c\u4f46\u7531\u4e8e\u5176\u79bb\u6563\u6027\u8d28\uff0c\u4e0d\u9002\u5408\u505a\u4e3a\u635f\u5931\u51fd\u6570\u4f18\u5316\u795e\u7ecf\u7f51\u7edc\u3002<\/p>\n\n\n\n<p><strong>2. 
\u68c0\u67e5\u6a21\u578b\u8bad\u7ec3\u8fc7\u7a0b\uff0c\u8bc6\u522b\u6f5c\u5728\u95ee\u9898\u3002<\/strong><\/p>\n\n\n\n<p>\u5982\u679c\u6a21\u578b\u7684\u635f\u5931\u6216\u8005\u8bc4\u4f30\u6307\u6807\u8868\u73b0\u5f02\u5e38\uff0c\u901a\u5e38\u9700\u8981\u6253\u5370\u6a21\u578b\u6bcf\u4e00\u5c42\u7684\u8f93\u5165\u548c\u8f93\u51fa\u6765\u5b9a\u4f4d\u95ee\u9898\uff0c\u5206\u6790\u6bcf\u4e00\u5c42\u7684\u5185\u5bb9\u6765\u83b7\u53d6\u9519\u8bef\u7684\u539f\u56e0\u3002<\/p>\n\n\n\n<p><strong>3. \u52a0\u5165\u6821\u9a8c\u6216\u6d4b\u8bd5\uff0c\u66f4\u597d\u8bc4\u4ef7\u6a21\u578b\u6548\u679c\u3002<\/strong><\/p>\n\n\n\n<p>\u7406\u60f3\u7684\u6a21\u578b\u8bad\u7ec3\u7ed3\u679c\u662f\u5728\u8bad\u7ec3\u96c6\u548c\u9a8c\u8bc1\u96c6\u4e0a\u5747\u6709\u8f83\u9ad8\u7684\u51c6\u786e\u7387\uff0c\u5982\u679c\u8bad\u7ec3\u96c6\u7684\u51c6\u786e\u7387\u4f4e\u4e8e\u9a8c\u8bc1\u96c6\uff0c\u8bf4\u660e\u7f51\u7edc\u8bad\u7ec3\u7a0b\u5ea6\u4e0d\u591f\uff1b\u5982\u679c\u8bad\u7ec3\u96c6\u7684\u51c6\u786e\u7387\u9ad8\u4e8e\u9a8c\u8bc1\u96c6\uff0c\u53ef\u80fd\u662f\u53d1\u751f\u4e86\u8fc7\u62df\u5408\u73b0\u8c61\u3002\u901a\u8fc7\u5728\u4f18\u5316\u76ee\u6807\u4e2d\u52a0\u5165\u6b63\u5219\u5316\u9879\u7684\u529e\u6cd5\uff0c\u89e3\u51b3\u8fc7\u62df\u5408\u7684\u95ee\u9898\u3002<\/p>\n\n\n\n<p><strong>4. \u52a0\u5165\u6b63\u5219\u5316\u9879\uff0c\u907f\u514d\u6a21\u578b\u8fc7\u62df\u5408\u3002<\/strong><\/p>\n\n\n\n<p>\u98de\u6868\u6846\u67b6\u652f\u6301\u4e3a\u6574\u4f53\u53c2\u6570\u52a0\u5165\u6b63\u5219\u5316\u9879\uff0c\u8fd9\u662f\u901a\u5e38\u7684\u505a\u6cd5\u3002\u6b64\u5916\uff0c\u98de\u6868\u6846\u67b6\u4e5f\u652f\u6301\u4e3a\u67d0\u4e00\u5c42\u6216\u67d0\u4e00\u90e8\u5206\u7684\u7f51\u7edc\u5355\u72ec\u52a0\u5165\u6b63\u5219\u5316\u9879\uff0c\u4ee5\u8fbe\u5230\u7cbe\u7ec6\u8c03\u6574\u53c2\u6570\u8bad\u7ec3\u7684\u6548\u679c\u3002<\/p>\n\n\n\n<p><strong>5. 
\u53ef\u89c6\u5316\u5206\u6790\u3002<\/strong><\/p>\n\n\n\n<p>\u7528\u6237\u4e0d\u4ec5\u53ef\u4ee5\u901a\u8fc7\u6253\u5370\u6216\u4f7f\u7528matplotlib\u5e93\u4f5c\u56fe\uff0c\u98de\u6868\u8fd8\u63d0\u4f9b\u4e86\u66f4\u4e13\u4e1a\u7684\u53ef\u89c6\u5316\u5206\u6790\u5de5\u5177VisualDL\uff0c\u63d0\u4f9b\u4fbf\u6377\u7684\u53ef\u89c6\u5316\u5206\u6790\u65b9\u6cd5\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">\u8ba1\u7b97\u51c6\u786e\u7387\uff08Accuracy\uff09<\/h2>\n\n\n\n<p>paddle\u672c\u8eab\u63d0\u4f9b\u7684\u51c6\u786e\u7387api\u53ef\u76f4\u63a5\u8ba1\u7b97\u51c6\u786e\u7387<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u52a0\u8f7d\u76f8\u5173\u5e93\r\nimport os\r\nimport random\r\nimport paddle\r\nimport numpy as np\r\nfrom PIL import Image\r\nimport gzip\r\nimport json\r\n\r\n\r\n# \u5b9a\u4e49\u6570\u636e\u96c6\u8bfb\u53d6\u5668\r\ndef load_data(mode='train'):\r\n\r\n    # \u8bfb\u53d6\u6570\u636e\u6587\u4ef6\r\n    datafile = '.\/work\/mnist.json.gz'\r\n    print('loading mnist dataset from {} ......'.format(datafile))\r\n    data = json.load(gzip.open(datafile))\r\n    # \u8bfb\u53d6\u6570\u636e\u96c6\u4e2d\u7684\u8bad\u7ec3\u96c6\uff0c\u9a8c\u8bc1\u96c6\u548c\u6d4b\u8bd5\u96c6\r\n    train_set, val_set, eval_set = data\r\n\r\n    # \u6570\u636e\u96c6\u76f8\u5173\u53c2\u6570\uff0c\u56fe\u7247\u9ad8\u5ea6IMG_ROWS, \u56fe\u7247\u5bbd\u5ea6IMG_COLS\r\n    IMG_ROWS = 28\r\n    IMG_COLS = 28\r\n    # \u6839\u636e\u8f93\u5165mode\u53c2\u6570\u51b3\u5b9a\u4f7f\u7528\u8bad\u7ec3\u96c6\uff0c\u9a8c\u8bc1\u96c6\u8fd8\u662f\u6d4b\u8bd5\r\n    if mode == 'train':\r\n        imgs = train_set&#91;0]\r\n        labels = train_set&#91;1]\r\n    elif mode == 'valid':\r\n        imgs = val_set&#91;0]\r\n        labels = val_set&#91;1]\r\n    elif mode == 'eval':\r\n        imgs = eval_set&#91;0]\r\n        labels = eval_set&#91;1]\r\n    # \u83b7\u5f97\u6240\u6709\u56fe\u50cf\u7684\u6570\u91cf\r\n    imgs_length = len(imgs)\r\n    # 
\u9a8c\u8bc1\u56fe\u50cf\u6570\u91cf\u548c\u6807\u7b7e\u6570\u91cf\u662f\u5426\u4e00\u81f4\r\n    assert len(imgs) == len(labels), \\\r\n          \"length of train_imgs({}) should be the same as train_labels({})\".format(\r\n                  len(imgs), len(labels))\r\n\r\n    index_list = list(range(imgs_length))\r\n\r\n    # \u8bfb\u5165\u6570\u636e\u65f6\u7528\u5230\u7684batchsize\r\n    BATCHSIZE = 100\r\n\r\n    # \u5b9a\u4e49\u6570\u636e\u751f\u6210\u5668\r\n    def data_generator():\r\n        # \u8bad\u7ec3\u6a21\u5f0f\u4e0b\uff0c\u6253\u4e71\u8bad\u7ec3\u6570\u636e\r\n        if mode == 'train':\r\n            random.shuffle(index_list)\r\n        imgs_list = &#91;]\r\n        labels_list = &#91;]\r\n        # \u6309\u7167\u7d22\u5f15\u8bfb\u53d6\u6570\u636e\r\n        for i in index_list:\r\n            # \u8bfb\u53d6\u56fe\u50cf\u548c\u6807\u7b7e\uff0c\u8f6c\u6362\u5176\u5c3a\u5bf8\u548c\u7c7b\u578b\r\n            img = np.reshape(imgs&#91;i], &#91;1, IMG_ROWS, IMG_COLS]).astype('float32')\r\n            label = np.reshape(labels&#91;i], &#91;1]).astype('int64')\r\n            imgs_list.append(img) \r\n            labels_list.append(label)\r\n            # \u5982\u679c\u5f53\u524d\u6570\u636e\u7f13\u5b58\u8fbe\u5230\u4e86batch size\uff0c\u5c31\u8fd4\u56de\u4e00\u4e2a\u6279\u6b21\u6570\u636e\r\n            if len(imgs_list) == BATCHSIZE:\r\n                yield np.array(imgs_list), np.array(labels_list)\r\n                # \u6e05\u7a7a\u6570\u636e\u7f13\u5b58\u5217\u8868\r\n                imgs_list = &#91;]\r\n                labels_list = &#91;]\r\n\r\n        # \u5982\u679c\u5269\u4f59\u6570\u636e\u7684\u6570\u76ee\u5c0f\u4e8eBATCHSIZE\uff0c\r\n        # \u5219\u5269\u4f59\u6570\u636e\u4e00\u8d77\u6784\u6210\u4e00\u4e2a\u5927\u5c0f\u4e3alen(imgs_list)\u7684mini-batch\r\n        if len(imgs_list) > 0:\r\n            yield np.array(imgs_list), np.array(labels_list)\r\n\r\n    return data_generator\n\n# \u5b9a\u4e49\u6a21\u578b\u7ed3\u6784\r\nimport 
paddle.nn.functional as F\r\nfrom paddle.nn import Conv2D, MaxPool2D, Linear\r\n\r\n# \u591a\u5c42\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u5b9e\u73b0\r\nclass MNIST(paddle.nn.Layer):\r\n     def __init__(self):\r\n         super(MNIST, self).__init__()\r\n         \r\n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\r\n         self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)\r\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u6838\u7684\u5927\u5c0fkernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\r\n         self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)\r\n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\r\n         self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)\r\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u6838\u7684\u5927\u5c0fkernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\r\n         self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)\r\n         # \u5b9a\u4e49\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u8f93\u51fa\u7ef4\u5ea6\u662f10\r\n         self.fc = Linear(in_features=980, out_features=10)\r\n         \r\n   # \u5b9a\u4e49\u7f51\u7edc\u524d\u5411\u8ba1\u7b97\u8fc7\u7a0b\uff0c\u5377\u79ef\u540e\u7d27\u63a5\u7740\u4f7f\u7528\u6c60\u5316\u5c42\uff0c\u6700\u540e\u4f7f\u7528\u5168\u8fde\u63a5\u5c42\u8ba1\u7b97\u6700\u7ec8\u8f93\u51fa\r\n   # \u5377\u79ef\u5c42\u6fc0\u6d3b\u51fd\u6570\u4f7f\u7528Relu\uff0c\u5168\u8fde\u63a5\u5c42\u6fc0\u6d3b\u51fd\u6570\u4f7f\u7528softmax\r\n     def forward(self, inputs, label):\r\n         x = self.conv1(inputs)\r\n         x = F.sigmoid(x)\r\n   
      x = self.max_pool1(x)\r\n         x = self.conv2(x)\r\n         x = F.sigmoid(x)\r\n         x = self.max_pool2(x)\r\n         x = paddle.reshape(x, &#91;x.shape&#91;0], 980])\r\n         x = self.fc(x)\r\n         x = F.softmax(x)\r\n         if label is not None:\r\n             acc = paddle.metric.accuracy(input=x, label=label)\r\n             return x, acc\r\n         else:\r\n             return x\r\n\r\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\r\ntrain_loader = load_data('train')\r\n    \r\n#\u5728\u4f7f\u7528GPU\u673a\u5668\u65f6\uff0c\u53ef\u4ee5\u5c06use_gpu\u53d8\u91cf\u8bbe\u7f6e\u6210True\r\nuse_gpu = True\r\npaddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')\r\n\r\n#\u4ec5\u4f18\u5316\u7b97\u6cd5\u7684\u8bbe\u7f6e\u6709\u6240\u5dee\u522b\r\ndef train(model):\r\n    model = MNIST()\r\n    model.train()\r\n    \r\n    #\u56db\u79cd\u4f18\u5316\u7b97\u6cd5\u7684\u8bbe\u7f6e\u65b9\u6848\uff0c\u53ef\u4ee5\u9010\u4e00\u5c1d\u8bd5\u6548\u679c\r\n    # opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())\r\n    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())\r\n    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())\r\n    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\r\n    \r\n    EPOCH_NUM = 5\r\n    for epoch_id in range(EPOCH_NUM):\r\n        for batch_id, data in enumerate(train_loader()):\r\n            #\u51c6\u5907\u6570\u636e\r\n            images, labels = data\r\n            images = paddle.to_tensor(images)\r\n            labels = paddle.to_tensor(labels)\r\n            \r\n            #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\r\n            predicts, acc = model(images, labels)\r\n            \r\n            #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\r\n            loss = F.cross_entropy(predicts, 
labels)\r\n            avg_loss = paddle.mean(loss)\r\n            \r\n            #\u6bcf\u8bad\u7ec3\u4e86100\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\r\n            if batch_id % 200 == 0:\r\n                print(\"epoch: {}, batch: {}, loss is: {}, acc is {}\".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))\r\n                \r\n            #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\uff0c\u6d88\u9664\u68af\u5ea6\u7684\u8fc7\u7a0b\r\n            avg_loss.backward()\r\n            opt.step()\r\n            opt.clear_grad()\r\n\r\n    #\u4fdd\u5b58\u6a21\u578b\u53c2\u6570\r\n    paddle.save(model.state_dict(), 'mnist.pdparams')\r\n    \r\n#\u521b\u5efa\u6a21\u578b    \r\nmodel = MNIST()\r\n#\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\r\ntrain(model)<\/code><\/pre>\n\n\n\n<p>torch\u4e2d\u597d\u4f3c\u6ca1\u6709\u8fd9\u6837\u5305\u88c5\u597d\u7684api\uff0c\u4e0d\u8fc7\u539f\u7406\u4e0d\u590d\u6742\u6211\u4eec\u53ef\u4ee5\u81ea\u5df1\u5199\u3002<\/p>\n\n\n\n<p>\u8fd4\u56de\u7684predict\u7684size\u662f\u3010batchsize\uff0c10\u3011\uff0c\u4e5f\u5c31\u662f\u4e00\u6b21\u8fd4\u56de\u4e00\u4e2abatchsize\u7684\u7ed3\u679c\uff0c\u6bcf\u4e2a\u7ed3\u679c\u662f\u5341\u4e2a\u6570\u5b57\u7684list\uff0c\u662fsoftmax\u540e\u6bcf\u4e2alabel\u5bf9\u5e94\u7684\u53ef\u80fd\u6027\uff0c\u6211\u4eec\u53ea\u8981\u627e\u51fa\u6570\u503c\u6700\u5927\u7684\u90a3\u4e2a\u8bb0\u5f55\u4e0b\u4ed6\u7684label\u5c31\u662f\u9884\u6d4b\u7684\u7ed3\u679c\uff0c\u7136\u540e\u53bb\u548c\u771f\u5b9elabel\u5bf9\u6bd4\u4e00\u4e0b\u5c31\u884c\uff0c\u8ba1\u7b97\u4e00\u4e0b\u4e00\u4e2abatchsize\u4e2d\u9884\u6d4b\u5bf9\u4e86\u591a\u5c11\u4e2a\uff0c\u7136\u540e\u9664\u4ee5batchsize\uff0c\u975e\u5e38\u7b80\u5355\u7684\u5c0f\u5b66\u6570\u5b66\u3002<\/p>\n\n\n\n<p>torch\u4e2d\u63d0\u4f9b\u4e86max\u51fd\u6570\u53ef\u4ee5\u76f4\u63a5\u8fd4\u56de\u6700\u5927\u503c\u7684\u503c\u548c\u6807\u7b7e\uff0c\u8fd9\u5c31\u5f88\u65b9\u4fbf\u3002\u8fd9\u4e2a\u51fd\u657
0\u8fd8\u6709\u4e00\u4e2a\u53c2\u6570\u662f\u7ef4\u6570\uff0c\u5c31\u662f\u4ece\u54ea\u4e2a\u7ef4\u6570\u53bb\u8ba1\u7b97\uff0c\u4e8e\u662f\u6211\u4eec\u53ea\u7528torch.max(predict,1)[1]\u5c31\u53ef\u4ee5\u5f97\u5230\u4e00\u7ef4\u7684\u9884\u6d4b\u7ed3\u679c\uff0c\u4e0e\u771f\u5b9e\u6807\u7b7e\u6bd4\u5bf9\u4e00\u4e0b\u5c31\u662f\u5f53\u524dbatch\u7684\u51c6\u786e\u7387\u3002paddle\u8fd9\u4e2aapi\u8fd4\u56de\u7684\u5e94\u8be5\u5c31\u662f\u5f53\u524d\u6279\u6b21\u7684acc\u3002\u6211\u4eec\u8bbe\u7f6e\u7684\u662f200\u4e2abatch\u8fd4\u56de\u4e00\u6b21loss\u548cacc\uff0c\u5b9e\u9645\u4e0a\u6211\u4eec\u4e5f\u53ef\u4ee5\u628aacc\u505a\u4e00\u4e2a\u5e73\u5747\uff0c\u8fd4\u56de\u8fd9200\u4e2abatch\u7684\u5e73\u5747acc\uff0c\u8fd9\u6837\u66f4\u79d1\u5b66\u4e00\u70b9\u3002<\/p>\n\n\n\n<p>\u8fd9\u91cc\u601d\u8003\u4e86\u4e00\u4e0b\uff0c\u5176\u5b9epaddle\u8fd9\u4e2aapi\u653e\u5728\u4e86model\u91cc\u7684forward\u8fc7\u7a0b\uff0c\u610f\u5473\u7740\u6bcf\u4e2abatch\u90fd\u8ba1\u7b97\u4e86acc\uff0c\u4f46\u6211\u4eec\u53ea\u5728\u6bcf\u9694200\u4e2abatch\u624d\u8f93\u51fa\u4e00\u6b21acc\uff0closs\u7528\u4e8e\u68af\u5ea6\u4e0b\u964d\u6ca1\u529e\u6cd5\u6bcf\u4e2abatch\u90fd\u8981\u7b97\uff0c\u4f46\u662facc\u5982\u679c\u53ea\u662f\u8981\u8f93\u51fa\u6bcf200batch\u65f6\u5f53\u524dbatch\u7684\u90a3\u4e2aacc\u90a3\u5b8c\u5168\u6ca1\u5fc5\u8981\u6bcf\u4e2abatch\u90fd\u7b97acc\uff0c\u867d\u7136acc\u6bd4loss\u8ba1\u7b97\u7b80\u5355\u5f88\u591a\uff0c\u4f46\u80fd\u5199\u5728if\u5224\u65ad\u540e\u518d\u8ba1\u7b97\u8fd8\u662f\u80fd\u7701\u4e0b\u4e00\u70b9\u6027\u80fd\u7684\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch.nn import Conv2d,MaxPool2d,Linear\r\n#\u6570\u636e\u5904\u7406\u90e8\u5206\u4e4b\u540e\u7684\u4ee3\u7801\uff0c\u6570\u636e\u8bfb\u53d6\u7684\u90e8\u5206\u8c03\u7528load_data\u51fd\u6570\r\n# \u5b9a\u4e49\u7f51\u7edc\u7ed3\u6784\uff0c\u540c\u4e0a\u4e00\u8282\u6240\u4f7f\u7528\u7684\u7f51\u7edc\u7ed3\u6784\r\nclass Mnist(nn.Module):\r\n    def __init__(self):\r\n   
     super(Mnist,self).__init__()\r\n        self.conv1 = Conv2d(in_channels=1,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool1 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.conv2 = Conv2d(in_channels=20,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool2 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.fc = Linear(in_features=980,out_features=10)\r\n    def forward(self,x):\r\n        x = self.conv1(x)\r\n        x = torch.relu(x)\r\n        x = self.max_pool1(x)\r\n        x = self.conv2(x)\r\n        x = torch.relu(x)\r\n        x = self.max_pool2(x)\r\n        x = torch.reshape(x,&#91;x.shape&#91;0],-1])\r\n        x = self.fc(x)\r\n        x = F.softmax(x,dim=1)\r\n        return x\r\n\r\n# \u8bad\u7ec3\u914d\u7f6e\uff0c\u5e76\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\r\nmodel = Mnist()\r\nmodel=model.cuda()\r\nmodel.train(mode=True)\r\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\r\ntrain_loader = load_data('train')\r\noptimizer = optim.Adam(model.parameters(),lr= 0.001)\r\n\r\nBATCHSIZE=100\r\nEPOCH_NUM = 10\r\nfor epoch_id in range(EPOCH_NUM):\r\n    for batch_id, data in enumerate(train_loader()):\r\n        #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\r\n        image_data, label_data = data\r\n        image = torch.tensor(image_data).cuda()\r\n        label = torch.tensor(label_data).cuda()\r\n        image = torch.reshape(image,&#91;image.shape&#91;0],1,28,28]) \r\n        #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\r\n        predict = model(image)\r\n            \r\n        #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\r\n        loss = F.cross_entropy(predict, label.squeeze(dim=1)).cuda()\r\n        avg_loss = torch.mean(loss) \r\n        #\u6bcf\u8bad\u7ec3\u4e86200\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\r\n        if batch_id % 200 == 0:\r\n           
 correct=0\r\n            predict_label = torch.max(predict,1)&#91;1]\r\n            correct += (predict_label == label.squeeze(dim=1)).sum()\r\n            print(\"epoch: {}, batch: {}, loss is: {}, Accuracy:{:.3f}\".format(epoch_id, batch_id, avg_loss.cpu().detach().numpy(),correct\/BATCHSIZE))\r\n            \r\n        #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\r\n        avg_loss.backward()\r\n        optimizer.step()\r\n        model.zero_grad()<\/code><\/pre>\n\n\n\n<p>\u4e00\u5f00\u59cb\u5bf9\u6bd4\u4e86\u4e00\u4e0b\uff0c\u53d1\u73b0torch\u7684acc\u6700\u540e\u4e5f\u53ea\u67090.80\u5de6\u53f3\uff0c\u800cpaddle\u80fd\u52300.96\uff0c\u68c0\u67e5\u4e86\u4e00\u4e0b\u53d1\u73b0\u662f\u68af\u5ea6\u4e0b\u964d\u7684\u65b9\u6cd5\u4e0d\u4e00\u6837\uff0c\u628atorch\u7684\u4e5f\u6539\u6210adam\u4e4b\u540e\u6700\u540e\u4e00\u4e2abatch\u7684acc\u80fd\u52301\uff0c\u4e5f\u5c31\u662f\u8fd9\u4e00\u4e2abatch\u91cc100\u4e2a\u6570\u5b57\u5168\u90e8\u8bc6\u522b\u6b63\u786e\u4e86\u3002<\/p>\n\n\n\n<p>\u63a5\u7740\u628a\u4ee3\u7801\u6539\u6210\u8ba1\u7b97\u6bcf200\u4e2abatch\u7684\u5e73\u5747acc<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch.nn import Conv2d,MaxPool2d,Linear\r\n#\u6570\u636e\u5904\u7406\u90e8\u5206\u4e4b\u540e\u7684\u4ee3\u7801\uff0c\u6570\u636e\u8bfb\u53d6\u7684\u90e8\u5206\u8c03\u7528load_data\u51fd\u6570\r\n# \u5b9a\u4e49\u7f51\u7edc\u7ed3\u6784\uff0c\u540c\u4e0a\u4e00\u8282\u6240\u4f7f\u7528\u7684\u7f51\u7edc\u7ed3\u6784\r\nclass Mnist(nn.Module):\r\n    def __init__(self):\r\n        super(Mnist,self).__init__()\r\n        self.conv1 = Conv2d(in_channels=1,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool1 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.conv2 = Conv2d(in_channels=20,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool2 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.fc = Linear(in_features=980,out_features=10)\r\n    def forward(self,x):\r\n       
 x = self.conv1(x)\r\n        x = torch.relu(x)\r\n        x = self.max_pool1(x)\r\n        x = self.conv2(x)\r\n        x = torch.relu(x)\r\n        x = self.max_pool2(x)\r\n        x = torch.reshape(x,&#91;x.shape&#91;0],-1])\r\n        x = self.fc(x)\r\n        x = F.softmax(x,dim=1)\r\n        return x\r\n\r\n# \u8bad\u7ec3\u914d\u7f6e\uff0c\u5e76\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\r\nmodel = Mnist()\r\nmodel=model.cuda()\r\nmodel.train(mode=True)\r\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\r\ntrain_loader = load_data('train')\r\noptimizer = optim.Adam(model.parameters(),lr= 0.001)\r\n\r\nBATCHSIZE=100\r\nEPOCH_NUM = 10\r\nfor epoch_id in range(EPOCH_NUM):\r\n    correct=0\r\n    for batch_id, data in enumerate(train_loader()):\r\n        #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\r\n        image_data, label_data = data\r\n        image = torch.tensor(image_data).cuda()\r\n        label = torch.tensor(label_data).cuda()\r\n        image = torch.reshape(image,&#91;image.shape&#91;0],1,28,28]) \r\n        #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\r\n        predict = model(image)\r\n        predict_label = torch.max(predict,1)&#91;1]\r\n        correct += (predict_label == label.squeeze(dim=1)).sum()\r\n        #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\r\n        loss = F.cross_entropy(predict, label.squeeze(dim=1)).cuda()\r\n        avg_loss = torch.mean(loss)\r\n        #\u6bcf\u8bad\u7ec3\u4e86200\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\r\n        if batch_id % 200 == 0:\r\n            print(\"epoch: {}, batch: {}, loss is: {}, Accuracy:{:.3f}\".format(epoch_id, batch_id, avg_loss.cpu().detach().numpy(),correct\/(BATCHSIZE*200)))\r\n            correct=0\r\n        
#\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\r\n        avg_loss.backward()\r\n        optimizer.step()\r\n        model.zero_grad()<\/code><\/pre>\n\n\n\n<p>\u8fd9\u6837\u53d1\u73b0\u6bcf\u4e2aepoch\u5f00\u59cbbatchid=0\u8f93\u51fa\u7684\u65f6\u5019\u53ea\u7d2f\u79ef\u4e86\u7b2c\u4e00\u4e2abatch\u7684correct\uff0c\u4f46\u53c8\u5f53\u6210\u4e86200\u4e2abatch\u7684correct\u603b\u6570\u9664\u4e86200\uff0c\u51fa\u4e86bug\uff0c\u4e8b\u5b9e\u4e0a\u6211\u4eec\u4e0d\u8be5\u8f93\u51fabatch_id=0\u65f6\u5019\u7684\u6570\u636e\uff0c\u800c\u5e94\u8be5\u5728\u6700\u540e\u4e00\u7ec4batch\u4e0d\u8db3200\u65f6\u7684\u53c2\u6570\u3002\u6700\u597d\u7684\u89e3\u51b3\u529e\u6cd5\u662f\u628a\u5934\u5c3e\u90fd\u4e0d\u8f93\u51fa\uff0c\u8fd9\u6837\u6700\u7b80\u5355\u3002<\/p>\n\n\n\n<p>if batch_id!=0 and batch_id % 200 == 0:\u8fd9\u6837\u5c31\u884c\u3002\u8981\u662f\u5acc\u4e00\u4e2aepoch\u5185\u8f93\u51fa\u4e24\u4e2a\u4e2d\u95f4\u7ed3\u679c\u592a\u5c11\uff0c\u5c31\u628a200\u6539\u6210150\u6216\u8005100\u5c31\u597d\u3002<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"400\" height=\"400\" src=\"http:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/645D54C940601872D15F32B384940E70.gif\" alt=\"\" class=\"wp-image-355\"\/><\/figure>\n\n\n\n<h2 class=\"wp-block-heading\">\u68c0\u67e5\u6a21\u578b\u8bad\u7ec3\u8fc7\u7a0b\uff0c\u8bc6\u522b\u6f5c\u5728\u8bad\u7ec3\u95ee\u9898<\/h2>\n\n\n\n<p>\u8fd9\u4e00\u90e8\u5206\u8bb2\u7684\u6253\u5370\u7f51\u7edc\u7ed3\u6784\u548c\u5185\u5bb9\u503c\u6765\u68c0\u67e5\u7ed3\u6784\u662f\u5426\u6b63\u786e\uff0c\u8bbe\u8ba1\u662f\u5426\u6709\u95ee\u9898\uff0c\u6709\u6ca1\u6709\u65e0\u6548\u90e8\u5206<\/p>\n\n\n\n<p>\u793a\u4f8b\u4ee3\u7801\u5728\u8bad\u7ec3\u5f00\u59cb\u65f6\u6253\u5370\u51fa\u7f51\u7edc\u7ed3\u6784\uff0c\u5e76\u5728\u6bcf\u4e2aepoch\u7684\u7b2c401\u4e2abatch\u6253\u5370\u6bcf\u4e00\u5c42\u7684\u5185\u5bb9\u503c\u3002<\/p>\n\n\n\n<pre 
class=\"wp-block-code\"><code>import paddle.nn.functional as F\r\n# \u5b9a\u4e49\u6a21\u578b\u7ed3\u6784\r\nclass MNIST(paddle.nn.Layer):\r\n     def __init__(self):\r\n         super(MNIST, self).__init__()\r\n         \r\n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\r\n         self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)\r\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u6838\u7684\u5927\u5c0fkernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\r\n         self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)\r\n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\r\n         self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)\r\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u6838\u7684\u5927\u5c0fkernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\r\n         self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)\r\n         # \u5b9a\u4e49\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u8f93\u51fa\u7ef4\u5ea6\u662f10\r\n         self.fc = Linear(in_features=980, out_features=10)\r\n     \r\n     #\u52a0\u5165\u5bf9\u6bcf\u4e00\u5c42\u8f93\u5165\u548c\u8f93\u51fa\u7684\u5c3a\u5bf8\u548c\u6570\u636e\u5185\u5bb9\u7684\u6253\u5370\uff0c\u6839\u636echeck\u53c2\u6570\u51b3\u7b56\u662f\u5426\u6253\u5370\u6bcf\u5c42\u7684\u53c2\u6570\u548c\u8f93\u51fa\u5c3a\u5bf8\r\n     # \u5377\u79ef\u5c42\u6fc0\u6d3b\u51fd\u6570\u4f7f\u7528Relu\uff0c\u5168\u8fde\u63a5\u5c42\u6fc0\u6d3b\u51fd\u6570\u4f7f\u7528softmax\r\n     def forward(self, inputs, label=None, check_shape=False, check_content=False):\r\n         # 
\u7ed9\u4e0d\u540c\u5c42\u7684\u8f93\u51fa\u4e0d\u540c\u547d\u540d\uff0c\u65b9\u4fbf\u8c03\u8bd5\r\n         outputs1 = self.conv1(inputs)\r\n         outputs2 = F.relu(outputs1)\r\n         outputs3 = self.max_pool1(outputs2)\r\n         outputs4 = self.conv2(outputs3)\r\n         outputs5 = F.relu(outputs4)\r\n         outputs6 = self.max_pool2(outputs5)\r\n         outputs6 = paddle.reshape(outputs6, &#91;outputs6.shape&#91;0], -1])\r\n         outputs7 = self.fc(outputs6)\r\n         outputs8 = F.softmax(outputs7)\r\n         \r\n         # \u9009\u62e9\u662f\u5426\u6253\u5370\u795e\u7ecf\u7f51\u7edc\u6bcf\u5c42\u7684\u53c2\u6570\u5c3a\u5bf8\u548c\u8f93\u51fa\u5c3a\u5bf8\uff0c\u9a8c\u8bc1\u7f51\u7edc\u7ed3\u6784\u662f\u5426\u8bbe\u7f6e\u6b63\u786e\r\n         if check_shape:\r\n             # \u6253\u5370\u6bcf\u5c42\u7f51\u7edc\u8bbe\u7f6e\u7684\u8d85\u53c2\u6570-\u5377\u79ef\u6838\u5c3a\u5bf8\uff0c\u5377\u79ef\u6b65\u957f\uff0c\u5377\u79efpadding\uff0c\u6c60\u5316\u6838\u5c3a\u5bf8\r\n             print(\"\\n########## print network layer's superparams ##############\")\r\n             print(\"conv1-- kernel_size:{}, padding:{}, stride:{}\".format(self.conv1.weight.shape, self.conv1._padding, self.conv1._stride))\r\n             print(\"conv2-- kernel_size:{}, padding:{}, stride:{}\".format(self.conv2.weight.shape, self.conv2._padding, self.conv2._stride))\r\n             #print(\"max_pool1-- kernel_size:{}, padding:{}, stride:{}\".format(self.max_pool1.pool_size, self.max_pool1.pool_stride, self.max_pool1._stride))\r\n             #print(\"max_pool2-- kernel_size:{}, padding:{}, stride:{}\".format(self.max_pool2.weight.shape, self.max_pool2._padding, self.max_pool2._stride))\r\n             print(\"fc-- weight_size:{}, bias_size_{}\".format(self.fc.weight.shape, self.fc.bias.shape))\r\n             \r\n             # \u6253\u5370\u6bcf\u5c42\u7684\u8f93\u51fa\u5c3a\u5bf8\r\n             print(\"\\n########## print shape of features of every layer 
###############\")\r\n             print(\"inputs_shape: {}\".format(inputs.shape))\r\n             print(\"outputs1_shape: {}\".format(outputs1.shape))\r\n             print(\"outputs2_shape: {}\".format(outputs2.shape))\r\n             print(\"outputs3_shape: {}\".format(outputs3.shape))\r\n             print(\"outputs4_shape: {}\".format(outputs4.shape))\r\n             print(\"outputs5_shape: {}\".format(outputs5.shape))\r\n             print(\"outputs6_shape: {}\".format(outputs6.shape))\r\n             print(\"outputs7_shape: {}\".format(outputs7.shape))\r\n             print(\"outputs8_shape: {}\".format(outputs8.shape))\r\n             \r\n         # \u9009\u62e9\u662f\u5426\u6253\u5370\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u53c2\u6570\u548c\u8f93\u51fa\u5185\u5bb9\uff0c\u53ef\u7528\u4e8e\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u8c03\u8bd5\r\n         if check_content:\r\n            # \u6253\u5370\u5377\u79ef\u5c42\u7684\u53c2\u6570-\u5377\u79ef\u6838\u6743\u91cd\uff0c\u6743\u91cd\u53c2\u6570\u8f83\u591a\uff0c\u6b64\u5904\u53ea\u6253\u5370\u90e8\u5206\u53c2\u6570\r\n             print(\"\\n########## print convolution layer's kernel ###############\")\r\n             print(\"conv1 params -- kernel weights:\", self.conv1.weight&#91;0]&#91;0])\r\n             print(\"conv2 params -- kernel weights:\", self.conv2.weight&#91;0]&#91;0])\r\n\r\n             # \u521b\u5efa\u968f\u673a\u6570\uff0c\u968f\u673a\u6253\u5370\u67d0\u4e00\u4e2a\u901a\u9053\u7684\u8f93\u51fa\u503c\r\n             idx1 = np.random.randint(0, outputs1.shape&#91;1])\r\n             idx2 = np.random.randint(0, outputs4.shape&#91;1])\r\n             # \u6253\u5370\u5377\u79ef-\u6c60\u5316\u540e\u7684\u7ed3\u679c\uff0c\u4ec5\u6253\u5370batch\u4e2d\u7b2c\u4e00\u4e2a\u56fe\u50cf\u5bf9\u5e94\u7684\u7279\u5f81\r\n             print(\"\\nThe {}th channel of conv1 layer: \".format(idx1), outputs1&#91;0]&#91;idx1])\r\n             print(\"The {}th channel of conv2 layer: \".format(idx2), 
outputs4&#91;0]&#91;idx2])\r\n             print(\"The output of last layer:\", outputs8&#91;0], '\\n')\r\n            \r\n        # \u5982\u679clabel\u4e0d\u662fNone\uff0c\u5219\u8ba1\u7b97\u5206\u7c7b\u7cbe\u5ea6\u5e76\u8fd4\u56de\r\n         if label is not None:\r\n             acc = paddle.metric.accuracy(input=outputs8, label=label)\r\n             return outputs8, acc\r\n         else:\r\n             return outputs8\r\n\r\n#\u5728\u4f7f\u7528GPU\u673a\u5668\u65f6\uff0c\u53ef\u4ee5\u5c06use_gpu\u53d8\u91cf\u8bbe\u7f6e\u6210True\r\nuse_gpu = True\r\npaddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')    \r\n\r\ndef train(model):\r\n    model = MNIST()\r\n    model.train()\r\n    \r\n    #\u56db\u79cd\u4f18\u5316\u7b97\u6cd5\u7684\u8bbe\u7f6e\u65b9\u6848\uff0c\u53ef\u4ee5\u9010\u4e00\u5c1d\u8bd5\u6548\u679c\r\n    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())\r\n    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())\r\n    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())\r\n    # opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())\r\n    \r\n    EPOCH_NUM = 1\r\n    for epoch_id in range(EPOCH_NUM):\r\n        for batch_id, data in enumerate(train_loader()):\r\n            #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\r\n            images, labels = data\r\n            images = paddle.to_tensor(images)\r\n            labels = paddle.to_tensor(labels)\r\n            \r\n            #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\uff0c\u540c\u65f6\u62ff\u5230\u6a21\u578b\u8f93\u51fa\u503c\u548c\u5206\u7c7b\u51c6\u786e\u7387\r\n            if batch_id == 0 and epoch_id==0:\r\n                # \u6253\u5370\u6a21\u578b\u53c2\u6570\u548c\u6bcf\u5c42\u8f93\u51fa\u7684\u5c3a\u5bf8\r\n                predicts, acc = model(images, labels, check_shape=True, check_content=False)\r\n            elif 
batch_id==401:\r\n                # \u6253\u5370\u6a21\u578b\u53c2\u6570\u548c\u6bcf\u5c42\u8f93\u51fa\u7684\u503c\r\n                predicts, acc = model(images, labels, check_shape=False, check_content=True)\r\n            else:\r\n                predicts, acc = model(images, labels)\r\n            \r\n            #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\r\n            loss = F.cross_entropy(predicts, labels)\r\n            avg_loss = paddle.mean(loss)\r\n            \r\n            #\u6bcf\u8bad\u7ec3\u4e86100\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\r\n            if batch_id % 200 == 0:\r\n                print(\"epoch: {}, batch: {}, loss is: {}, acc is {}\".format(epoch_id, batch_id, avg_loss.numpy(), acc.numpy()))\r\n            \r\n            #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\r\n            avg_loss.backward()\r\n            opt.step()\r\n            opt.clear_grad()\r\n\r\n    #\u4fdd\u5b58\u6a21\u578b\u53c2\u6570\r\n    paddle.save(model.state_dict(), 'mnist_test.pdparams')\r\n    \r\n#\u521b\u5efa\u6a21\u578b    \r\nmodel = MNIST()\r\n#\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\r\ntrain(model)\r\n\r\nprint(\"Model has been saved.\")<\/code><\/pre>\n\n\n\n<p>torch\u7248\u672c\u4e5f\u540c\u6837\u52a0\u5165\u8f93\u51fa\u7ed3\u6784\u548c\u5185\u5bb9\u503c\u7684\u4ee3\u7801\uff0c\u6ce8\u610ftorch\u7684conv\u7684\u53d8\u91cf\u6ca1\u6709\u4e0b\u5212\u7ebf<\/p>\n\n\n\n<p>paddle\u662fself.conv1._padding\u800ctorch\u7684\u662fself.conv1.padding<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch.nn import Conv2d,MaxPool2d,Linear\r\n#\u6570\u636e\u5904\u7406\u90e8\u5206\u4e4b\u540e\u7684\u4ee3\u7801\uff0c\u6570\u636e\u8bfb\u53d6\u7684\u90e8\u5206\u8c03\u7528load_data\u51fd\u6570\r\n# \u5b9a\u4e49\u7f51\u7edc\u7ed3\u6784\uff0c\u540c\u4e0a\u4e00\u8282\u6240\u4f7f\u7528\u7684\u7f51\u7edc\u7ed3\u6784\r\nclass 
Mnist(nn.Module):\r\n    def __init__(self):\r\n        super(Mnist,self).__init__()\r\n        self.conv1 = Conv2d(in_channels=1,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool1 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.conv2 = Conv2d(in_channels=20,out_channels=20,kernel_size=5,stride=1,padding=2)\r\n        self.max_pool2 = MaxPool2d(kernel_size=2,stride=2)\r\n        self.fc = Linear(in_features=980,out_features=10)\r\n    def forward(self,inputs,check_shape=False,check_content=False):\r\n        outputs1 = self.conv1(inputs)\r\n        outputs2 = torch.relu(outputs1)\r\n        outputs3 = self.max_pool1(outputs2)\r\n        outputs4 = self.conv2(outputs3)\r\n        outputs5 = torch.relu(outputs4)\r\n        outputs6 = self.max_pool2(outputs5)\r\n        outputs6 = torch.reshape(outputs6,&#91;outputs6.shape&#91;0],-1])\r\n        outputs7 = self.fc(outputs6)\r\n        outputs8 = F.softmax(outputs7,dim=1)\r\n        if check_shape:\r\n            # \u6253\u5370\u6bcf\u5c42\u7f51\u7edc\u8bbe\u7f6e\u7684\u8d85\u53c2\u6570-\u5377\u79ef\u6838\u5c3a\u5bf8\uff0c\u5377\u79ef\u6b65\u957f\uff0c\u5377\u79efpadding\uff0c\u6c60\u5316\u6838\u5c3a\u5bf8\r\n            print(\"\\n########## print network layer's superparams ##############\")\r\n            print(\"conv1-- kernel_size:{}, padding:{}, stride:{}\".format(self.conv1.weight.shape, self.conv1.padding, self.conv1.stride))\r\n            print(\"conv2-- kernel_size:{}, padding:{}, stride:{}\".format(self.conv2.weight.shape, self.conv2.padding, self.conv2.stride))\r\n            #print(\"max_pool1-- kernel_size:{}, padding:{}, stride:{}\".format(self.max_pool1.pool_size, self.max_pool1.pool_stride, self.max_pool1._stride))\r\n            #print(\"max_pool2-- kernel_size:{}, padding:{}, stride:{}\".format(self.max_pool2.weight.shape, self.max_pool2._padding, self.max_pool2._stride))\r\n            print(\"fc-- weight_size:{}, bias_size_{}\".format(self.fc.weight.shape, 
self.fc.bias.shape))\r\n             \r\n            # \u6253\u5370\u6bcf\u5c42\u7684\u8f93\u51fa\u5c3a\u5bf8\r\n            print(\"\\n########## print shape of features of every layer ###############\")\r\n            print(\"inputs_shape: {}\".format(inputs.shape))\r\n            print(\"outputs1_shape: {}\".format(outputs1.shape))\r\n            print(\"outputs2_shape: {}\".format(outputs2.shape))\r\n            print(\"outputs3_shape: {}\".format(outputs3.shape))\r\n            print(\"outputs4_shape: {}\".format(outputs4.shape))\r\n            print(\"outputs5_shape: {}\".format(outputs5.shape))\r\n            print(\"outputs6_shape: {}\".format(outputs6.shape))\r\n            print(\"outputs7_shape: {}\".format(outputs7.shape))\r\n            print(\"outputs8_shape: {}\".format(outputs8.shape))\r\n            \r\n         # \u9009\u62e9\u662f\u5426\u6253\u5370\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u53c2\u6570\u548c\u8f93\u51fa\u5185\u5bb9\uff0c\u53ef\u7528\u4e8e\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u7684\u8c03\u8bd5\r\n        if check_content:\r\n            # \u6253\u5370\u5377\u79ef\u5c42\u7684\u53c2\u6570-\u5377\u79ef\u6838\u6743\u91cd\uff0c\u6743\u91cd\u53c2\u6570\u8f83\u591a\uff0c\u6b64\u5904\u53ea\u6253\u5370\u90e8\u5206\u53c2\u6570\r\n            print(\"\\n########## print convolution layer's kernel ###############\")\r\n            print(\"conv1 params -- kernel weights:\", self.conv1.weight&#91;0]&#91;0])\r\n            print(\"conv2 params -- kernel weights:\", self.conv2.weight&#91;0]&#91;0])\r\n            # \u521b\u5efa\u968f\u673a\u6570\uff0c\u968f\u673a\u6253\u5370\u67d0\u4e00\u4e2a\u901a\u9053\u7684\u8f93\u51fa\u503c\r\n            idx1 = np.random.randint(0, outputs1.shape&#91;1])\r\n            idx2 = np.random.randint(0, outputs4.shape&#91;1])\r\n            # \u6253\u5370\u5377\u79ef-\u6c60\u5316\u540e\u7684\u7ed3\u679c\uff0c\u4ec5\u6253\u5370batch\u4e2d\u7b2c\u4e00\u4e2a\u56fe\u50cf\u5bf9\u5e94\u7684\u7279\u5f81\r\n            print(\"\\nThe 
{}th channel of conv1 layer: \".format(idx1), outputs1&#91;0]&#91;idx1])\r\n            print(\"The {}th channel of conv2 layer: \".format(idx2), outputs4&#91;0]&#91;idx2])\r\n            print(\"The output of last layer:\", outputs8&#91;0], '\\n')\r\n        \r\n        return outputs8\r\n            \r\n# \u8bad\u7ec3\u914d\u7f6e\uff0c\u5e76\u542f\u52a8\u8bad\u7ec3\u8fc7\u7a0b\r\nmodel = Mnist()\r\nmodel=model.cuda()\r\nmodel.train(mode=True)\r\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\r\ntrain_loader = load_data('train')\r\noptimizer = optim.Adam(model.parameters(),lr= 0.001)\r\n\r\nBATCHSIZE=100\r\nEPOCH_NUM = 10\r\nfor epoch_id in range(EPOCH_NUM):\r\n    correct=0\r\n    for batch_id, data in enumerate(train_loader()):\r\n        #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\r\n        image_data, label_data = data\r\n        image = torch.tensor(image_data).cuda()\r\n        label = torch.tensor(label_data).cuda()\r\n        image = torch.reshape(image,&#91;image.shape&#91;0],1,28,28]) \r\n        #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\r\n        if batch_id==0 and epoch_id==0:\r\n            predict=model(image,check_shape=True)\r\n        elif batch_id==401:\r\n            predict=model(image,check_content=True)\r\n        else:\r\n            predict = model(image)\r\n        predict_label = torch.max(predict,1)&#91;1]\r\n        correct += (predict_label == label.squeeze(dim=1)).sum()\r\n        #\u8ba1\u7b97\u635f\u5931\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\r\n        loss = F.cross_entropy(predict, label.squeeze(dim=1)).cuda()\r\n        avg_loss = torch.mean(loss)\r\n        #\u6bcf\u8bad\u7ec3\u4e86150\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\r\n        if batch_id!=0 and batch_id % 150 == 0:\r\n            print(\"epoch: {}, batch: {}, loss is: {}, Accuracy:{:.3f}\".format(epoch_id, batch_id, 
avg_loss.cpu().detach().numpy(),\\\r\n                                                                              correct\/(BATCHSIZE*150)))\r\n            correct=0\r\n        #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\r\n        avg_loss.backward()\r\n        optimizer.step()\r\n        model.zero_grad()\ntorch.save(model.state_dict(), 'mnist_test.pdparams')<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u52a0\u5165\u6821\u9a8c\u6216\u6d4b\u8bd5\uff0c\u66f4\u597d\u8bc4\u4ef7\u6a21\u578b\u6548\u679c<\/h2>\n\n\n\n<p>\u5982\u4e0b\u7a0b\u5e8f\u8bfb\u53d6\u4e0a\u4e00\u6b65\u8bad\u7ec3\u4fdd\u5b58\u7684\u6a21\u578b\u53c2\u6570\uff0c\u8bfb\u53d6\u6821\u9a8c\u6570\u636e\u96c6\uff0c\u5e76\u6d4b\u8bd5\u6a21\u578b\u5728\u6821\u9a8c\u6570\u636e\u96c6\u4e0a\u7684\u6548\u679c\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def evaluation(model):\r\n    print('start evaluation .......')\r\n    # \u5b9a\u4e49\u9884\u6d4b\u8fc7\u7a0b\r\n    params_file_path = 'mnist_test.pdparams'\r\n    # \u52a0\u8f7d\u6a21\u578b\u53c2\u6570\r\n    param_dict = paddle.load(params_file_path)\r\n    model.load_dict(param_dict)\r\n\r\n    model.eval()\r\n    eval_loader = load_data('eval')\r\n\r\n    acc_set = &#91;]\r\n    avg_loss_set = &#91;]\r\n    for batch_id, data in enumerate(eval_loader()):\r\n        images, labels = data\r\n        images = paddle.to_tensor(images)\r\n        labels = paddle.to_tensor(labels)\r\n        predicts, acc = model(images, labels)\r\n        loss = F.cross_entropy(input=predicts, label=labels)\r\n        avg_loss = paddle.mean(loss)\r\n        acc_set.append(float(acc.numpy()))\r\n        avg_loss_set.append(float(avg_loss.numpy()))\r\n    \r\n    #\u8ba1\u7b97\u591a\u4e2abatch\u7684\u5e73\u5747\u635f\u5931\u548c\u51c6\u786e\u7387\r\n    acc_val_mean = np.array(acc_set).mean()\r\n    avg_loss_val_mean = np.array(avg_loss_set).mean()\r\n\r\n    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))\r\n\r\nmodel = 
MNIST()\r\nevaluation(model)<\/code><\/pre>\n\n\n\n<p>torch\u7684\u4fdd\u5b58\u6a21\u578b\u548c\u52a0\u8f7d\u548cpaddle\u53c8\u662f\u5173\u952e\u8bcd\u6709\u533a\u522b\uff0c\u52a0\u8f7d\u53c2\u6570\u662fload_state_dict<\/p>\n\n\n\n<p>\u6ce8\u610f\u8fd9\u91cctorch.save(model.state_dict(),'mnist_test.pdparams')\uff0cstate_dict\u5fc5\u987b\u8981\u8ddf\u4e2a\u62ec\u53f7\uff0c\u8981\u4e0d\u7136\u4e4b\u540eload\u65f6\u871c\u6c41\u62a5\u9519\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def evaluation(model):\r\n    print('start evaluation .......')\r\n    model.load_state_dict(torch.load(r'.\/mnist_test.pdparams'))\r\n\r\n    model.eval()\r\n    eval_loader = load_data('eval')\r\n\r\n    acc_set = &#91;]\r\n    avg_loss_set = &#91;]\r\n    for batch_id, data in enumerate(eval_loader()):\r\n        images, labels = data\r\n        images = torch.tensor(images)\r\n        labels = torch.tensor(labels)\r\n        predicts = model(images)\r\n        predict_labels = torch.max(predicts, 1)&#91;1]\r\n        correct=0\r\n        correct += (predict_labels == labels.squeeze(dim=1)).sum()\r\n        acc=correct\/100\r\n        print(acc)\r\n        loss = F.cross_entropy(input=predicts, target=labels.squeeze(dim=1))\r\n        avg_loss = torch.mean(loss)\r\n        acc_set.append(float(acc.numpy()))\r\n        avg_loss_set.append(float(avg_loss.detach().numpy()))\r\n    \r\n    #\u8ba1\u7b97\u591a\u4e2abatch\u7684\u5e73\u5747\u635f\u5931\u548c\u51c6\u786e\u7387\r\n    acc_val_mean = np.array(acc_set).mean()\r\n    avg_loss_val_mean = np.array(avg_loss_set).mean()\r\n\r\n    print('loss={}, acc={}'.format(avg_loss_val_mean, acc_val_mean))\r\n\r\nmodel = Mnist()\r\nevaluation(model)<\/code><\/pre>\n\n\n\n<h2 
class=\"wp-block-heading\">\u52a0\u5165\u6b63\u5219\u5316\u9879\uff0c\u907f\u514d\u6a21\u578b\u8fc7\u62df\u5408<\/h2>\n\n\n\n<p>\u4e3a\u4e86\u9632\u6b62\u6a21\u578b\u8fc7\u62df\u5408\uff0c\u5728\u6ca1\u6709\u6269\u5145\u6837\u672c\u91cf\u7684\u53ef\u80fd\u4e0b\uff0c\u53ea\u80fd\u964d\u4f4e\u6a21\u578b\u7684\u590d\u6742\u5ea6\uff0c\u53ef\u4ee5\u901a\u8fc7\u9650\u5236\u53c2\u6570\u7684\u6570\u91cf\u6216\u53ef\u80fd\u53d6\u503c\uff08\u53c2\u6570\u503c\u5c3d\u91cf\u5c0f\uff09\u5b9e\u73b0\u3002<\/p>\n\n\n\n<p>\u5177\u4f53\u6765\u8bf4\uff0c\u5728\u6a21\u578b\u7684\u4f18\u5316\u76ee\u6807\uff08\u635f\u5931\uff09\u4e2d\u4eba\u4e3a\u52a0\u5165\u5bf9\u53c2\u6570\u89c4\u6a21\u7684\u60e9\u7f5a\u9879\u3002\u5f53\u53c2\u6570\u8d8a\u591a\u6216\u53d6\u503c\u8d8a\u5927\u65f6\uff0c\u8be5\u60e9\u7f5a\u9879\u5c31\u8d8a\u5927\u3002\u901a\u8fc7\u8c03\u6574\u60e9\u7f5a\u9879\u7684\u6743\u91cd\u7cfb\u6570\uff0c\u53ef\u4ee5\u4f7f\u6a21\u578b\u5728\u201c\u5c3d\u91cf\u51cf\u5c11\u8bad\u7ec3\u635f\u5931\u201d\u548c\u201c\u4fdd\u6301\u6a21\u578b\u7684\u6cdb\u5316\u80fd\u529b\u201d\u4e4b\u95f4\u53d6\u5f97\u5e73\u8861\u3002\u6cdb\u5316\u80fd\u529b\u8868\u793a\u6a21\u578b\u5728\u6ca1\u6709\u89c1\u8fc7\u7684\u6837\u672c\u4e0a\u4f9d\u7136\u6709\u6548\u3002\u6b63\u5219\u5316\u9879\u7684\u5b58\u5728\uff0c\u589e\u52a0\u4e86\u6a21\u578b\u5728\u8bad\u7ec3\u96c6\u4e0a\u7684\u635f\u5931\u3002<\/p>\n\n\n\n<p>\u98de\u6868\u652f\u6301\u4e3a\u6240\u6709\u53c2\u6570\u52a0\u4e0a\u7edf\u4e00\u7684\u6b63\u5219\u5316\u9879\uff0c\u4e5f\u652f\u6301\u4e3a\u7279\u5b9a\u7684\u53c2\u6570\u6dfb\u52a0\u6b63\u5219\u5316\u9879\u3002\u524d\u8005\u7684\u5b9e\u73b0\u5982\u4e0b\u4ee3\u7801\u6240\u793a\uff0c\u4ec5\u5728\u4f18\u5316\u5668\u4e2d\u8bbe\u7f6e<code>weight_decay<\/code>\u53c2\u6570\u5373\u53ef\u5b9e\u73b0\u3002\u4f7f\u7528\u53c2\u6570<code>coeff<\/code>\u8c03\u8282\u6b63\u5219\u5316\u9879\u7684\u6743\u91cd\uff0c\u6743\u91cd\u8d8a\u5927\u65f6\uff0c\u5bf9\u6a21\u578b\u590d\u6742\u5ea6\u7684\u60e9\u7f5a\u8d8a\u9ad8\u3002<
\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>    #\u5404\u79cd\u4f18\u5316\u7b97\u6cd5\u5747\u53ef\u4ee5\u52a0\u5165\u6b63\u5219\u5316\u9879\uff0c\u907f\u514d\u8fc7\u62df\u5408\uff0c\u53c2\u6570regularization_coeff\u8c03\u8282\u6b63\u5219\u5316\u9879\u7684\u6743\u91cd\r\n    # opt = paddle.optimizer.SGD(learning_rate=0.01, weight_decay=paddle.regularizer.L2Decay(coeff=0.1), parameters=model.parameters())\r\n    opt = paddle.optimizer.Adam(learning_rate=0.001, weight_decay=paddle.regularizer.L2Decay(coeff=1e-5), parameters=model.parameters())  <\/code><\/pre>\n\n\n\n<p>torch\u7684\u4f18\u5316\u5668\u4e2d\u4e5f\u6709weight_decay\u53c2\u6570<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>optim_wdecay = torch.optim.SGD(net_weight_decay.parameters(), lr=lr_init, momentum=0.9, weight_decay=1e-2)<\/code><\/pre>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" decoding=\"async\" width=\"960\" height=\"960\" src=\"http:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/F1F95FFB6DCC95B1495F078E70D166B7.jpg\" alt=\"\" class=\"wp-image-367\" srcset=\"https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/F1F95FFB6DCC95B1495F078E70D166B7.jpg 960w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/F1F95FFB6DCC95B1495F078E70D166B7-300x300.jpg 300w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/F1F95FFB6DCC95B1495F078E70D166B7-150x150.jpg 150w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/F1F95FFB6DCC95B1495F078E70D166B7-768x768.jpg 768w\" sizes=\"auto, (max-width: 960px) 100vw, 960px\" 
\/><\/figure>\n","protected":false},"excerpt":{"rendered":"<p>\u4f18\u5316\u7b97\u6cd5\uff0c\u8bad\u7ec3\u8c03\u8bd5\u4e0e\u4f18\u5316<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[21],"tags":[],"class_list":["post-392","post","type-post","status-publish","format-standard","hentry","category-21"],"jetpack_featured_media_url":"","jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/392","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/comments?post=392"}],"version-history":[{"count":3,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/392\/revisions"}],"predecessor-version":[{"id":395,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/392\/revisions\/395"}],"wp:attachment":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/media?parent=392"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/categories?post=392"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/tags?post=392"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}