{"id":396,"date":"2022-08-04T23:08:26","date_gmt":"2022-08-04T15:08:26","guid":{"rendered":"http:\/\/www.gislxz.top\/?p=396"},"modified":"2022-08-04T23:09:24","modified_gmt":"2022-08-04T15:09:24","slug":"%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%888%ef%bc%89","status":"publish","type":"post","link":"https:\/\/www.gislxz.com\/index.php\/2022\/08\/04\/%e6%b7%b1%e5%ba%a6%e5%ad%a6%e4%b9%a0%e7%ac%94%e8%ae%b0%ef%bc%888%ef%bc%89\/","title":{"rendered":"\u6df1\u5ea6\u5b66\u4e60\u7b14\u8bb0\uff088\uff09"},"content":{"rendered":"\n<p>paddle\u5b98\u65b9\u6559\u7a0b\u7684\u4e0b\u4e00\u7ae0\u662f<a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/1607211\" target=\"_blank\"  rel=\"nofollow\" >\u6a21\u578b\u52a0\u8f7d\u53ca\u6062\u590d\u8bad\u7ec3<\/a><\/p>\n\n\n\n<p>\u9996\u5148\u662f\u6570\u636e\u52a0\u8f7d\u548c\u6a21\u578b\u5b9a\u4e49\uff0c\u4e0e\u4e4b\u524d\u4e00\u81f4<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import os\nimport random\nimport paddle\nfrom paddle.nn import Conv2D, MaxPool2D, Linear\nimport paddle.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\nimport gzip\nimport json\n\npaddle.seed(0)\nrandom.seed(0)\nnp.random.seed(0)\n\n# \u6570\u636e\u6587\u4ef6\ndatafile = '.\/work\/mnist.json.gz'\nprint('loading mnist dataset from {} ......'.format(datafile))\ndata = json.load(gzip.open(datafile))\ntrain_set, val_set, eval_set = data\n\n# \u6570\u636e\u96c6\u76f8\u5173\u53c2\u6570\uff0c\u56fe\u7247\u9ad8\u5ea6IMG_ROWS, \u56fe\u7247\u5bbd\u5ea6IMG_COLS\nIMG_ROWS = 28\nIMG_COLS = 28\nimgs, labels = train_set&#91;0], train_set&#91;1]\nprint(\"\u8bad\u7ec3\u6570\u636e\u96c6\u6570\u91cf: \", len(imgs))\nassert len(imgs) == len(labels), \\\n        \"length of train_imgs({}) should be the same as train_labels({})\".format(\n                len(imgs), len(labels))\n                \nfrom paddle.io import Dataset\n\nclass MnistDataset(Dataset):\n    def __init__(self):\n        self.IMG_COLS = 28\n        self.IMG_ROWS = 
28\n    def __getitem__(self, idx):\n        image = train_set&#91;0]&#91;idx]\n        image = np.array(image)\n        image = image.reshape((1, IMG_ROWS,IMG_COLS)).astype('float32')\n        label = train_set&#91;1]&#91;idx]\n        label = np.array(label)\n        label = label.astype('int64')\n        return image, label\n    def __len__(self):\n        return len(imgs)\n\n\n#\u8c03\u7528\u52a0\u8f7d\u6570\u636e\u7684\u51fd\u6570\ndataset = MnistDataset()\ntrain_loader = paddle.io.DataLoader(dataset, batch_size=100, shuffle=False, return_list=True)\n\n# \u5b9a\u4e49\u6a21\u578b\u7ed3\u6784\nclass MNIST(paddle.nn.Layer):\n     def __init__(self):\n         super(MNIST, self).__init__()\n         \n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\n         self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u5c42\u5377\u79ef\u6838kernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\n         self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)\n         # \u5b9a\u4e49\u5377\u79ef\u5c42\uff0c\u8f93\u51fa\u7279\u5f81\u901a\u9053out_channels\u8bbe\u7f6e\u4e3a20\uff0c\u5377\u79ef\u6838\u7684\u5927\u5c0fkernel_size\u4e3a5\uff0c\u5377\u79ef\u6b65\u957fstride=1\uff0cpadding=2\n         self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)\n         # \u5b9a\u4e49\u6c60\u5316\u5c42\uff0c\u6c60\u5316\u5c42\u5377\u79ef\u6838kernel_size\u4e3a2\uff0c\u6c60\u5316\u6b65\u957f\u4e3a2\n         self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)\n         # \u5b9a\u4e49\u4e00\u5c42\u5168\u8fde\u63a5\u5c42\uff0c\u8f93\u51fa\u7ef4\u5ea6\u662f10\n         self.fc = Linear(in_features=980, out_features=10)\n         \n    # 
\u5b9a\u4e49\u7f51\u7edc\u524d\u5411\u8ba1\u7b97\u8fc7\u7a0b\uff0c\u5377\u79ef\u540e\u7d27\u63a5\u7740\u4f7f\u7528\u6c60\u5316\u5c42\uff0c\u6700\u540e\u4f7f\u7528\u5168\u8fde\u63a5\u5c42\u8ba1\u7b97\u6700\u7ec8\u8f93\u51fa\n     def forward(self, inputs, label):\n         x = self.conv1(inputs)\n         x = F.relu(x)\n         x = self.max_pool1(x)\n         x = self.conv2(x)\n         x = F.relu(x)\n         x = self.max_pool2(x)\n         x = paddle.reshape(x, &#91;x.shape&#91;0], -1])\n         x = self.fc(x)\n         x = F.softmax(x)\n         if label is not None:\n             acc = paddle.metric.accuracy(input=x, label=label)\n             return x, acc\n         else:\n             return x<\/code><\/pre>\n\n\n\n<p>\u63a5\u7740\u6559\u7a0b\u4ecb\u7ecd\u4e86\u52a8\u6001\u5b66\u4e60\u7387lr\u600e\u4e48\u8bbe\u7f6e<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>class paddle.optimizer.lr.PolynomialDecay (learning_rate, decay_steps, end_lr=0.0001, power=1.0, cycle=False, last_epoch=-1, verbose=False)<\/code><\/pre>\n\n\n\n<pre class=\"wp-block-code\"><code>#\u5728\u4f7f\u7528GPU\u673a\u5668\u65f6\uff0c\u53ef\u4ee5\u5c06use_gpu\u53d8\u91cf\u8bbe\u7f6e\u6210True\nuse_gpu = False\npaddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')\n\nEPOCH_NUM = 5\nBATCH_SIZE = 100\n\npaddle.seed(0)\n\ndef train(model):\n\n    model.train()\n\n    BATCH_SIZE = 100\n    # \u5b9a\u4e49\u5b66\u4e60\u7387\uff0c\u5e76\u52a0\u8f7d\u4f18\u5316\u5668\u53c2\u6570\u5230\u6a21\u578b\u4e2d\n    total_steps = (int(50000\/\/BATCH_SIZE) + 1) * EPOCH_NUM\n    lr = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=total_steps, end_lr=0.001)\n    # \u4f7f\u7528Adam\u4f18\u5316\u5668\n    opt = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())\n    \n    for epoch_id in range(EPOCH_NUM):\n        for batch_id, data in enumerate(train_loader()):\n            #\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\n            
image_data = data&#91;0].reshape(&#91;BATCH_SIZE, 1, 28, 28])\n            label_data = data&#91;1].reshape(&#91;BATCH_SIZE, 1])\n            image = paddle.to_tensor(image_data)\n            label = paddle.to_tensor(label_data)\n            # if batch_id&lt;10:\n                # print(label.reshape(&#91;-1])&#91;:10])\n            #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\n            predict, acc = model(image, label)\n            avg_acc = paddle.mean(acc)\n            #\u8ba1\u7b97\u635f\u5931\uff0c\u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\n            loss = F.cross_entropy(predict, label)\n            avg_loss = paddle.mean(loss)\n            \n            #\u6bcf\u8bad\u7ec3\u4e86200\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\n            if batch_id % 200 == 0:\n                print(\"epoch: {}, batch: {}, loss is: {}, acc is {}\".format(epoch_id, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n            \n            #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\n            avg_loss.backward()\n            opt.step()\n            opt.clear_grad()\n\n            \n    \n            # \u4fdd\u5b58\u6a21\u578b\u53c2\u6570\u548c\u4f18\u5316\u5668\u7684\u53c2\u6570\n            paddle.save(model.state_dict(), '.\/checkpoint\/mnist_epoch{}'.format(epoch_id)+'.pdparams')\n            paddle.save(opt.state_dict(), '.\/checkpoint\/mnist_epoch{}'.format(epoch_id)+'.pdopt')\n    print(opt.state_dict().keys())\n\nmodel = 
MNIST()\ntrain(model)\n\nprint(model.state_dict().keys())<\/code><\/pre>\n\n\n\n<p>\u4e0a\u9762\u7684\u4ee3\u7801\u4e2d\u6bcf\u4e00\u8f6e\u90fd\u4fdd\u5b58\u4e86\u76f8\u5e94\u7684\u6a21\u578b\u53c2\u6570\u548c\u4f18\u5316\u5668\u53c2\u6570<\/p>\n\n\n\n<p>\u4e0b\u9762\u7684\u4ee3\u7801\u5c06\u5c55\u793a\u6062\u590d\u8bad\u7ec3\u7684\u8fc7\u7a0b\uff0c\u5e76\u9a8c\u8bc1\u6062\u590d\u8bad\u7ec3\u662f\u5426\u6210\u529f\u3002\u5176\u4e2d\uff0c\u6211\u4eec\u91cd\u65b0\u5b9a\u4e49\u4e00\u4e2a<code>train_again()<\/code>\u8bad\u7ec3\u51fd\u6570\uff0c\u52a0\u8f7d\u6a21\u578b\u53c2\u6570\u5e76\u4ece\u7b2c\u4e00\u4e2aepoch\u5f00\u59cb\u8bad\u7ec3\uff0c\u4ee5\u4fbf\u8bfb\u8005\u53ef\u4ee5\u6821\u9a8c\u6062\u590d\u8bad\u7ec3\u540e\u7684\u635f\u5931\u53d8\u5316\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>params_path = \".\/checkpoint\/mnist_epoch0\"\n#\u5728\u4f7f\u7528GPU\u673a\u5668\u65f6\uff0c\u53ef\u4ee5\u5c06use_gpu\u53d8\u91cf\u8bbe\u7f6e\u6210True\nuse_gpu = True\npaddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')\ndef train_again(model):\n    model.train()\n\n    # \u8bfb\u53d6\u53c2\u6570\u6587\u4ef6\n    params_dict = paddle.load(params_path+'.pdparams')\n    opt_dict = paddle.load(params_path+'.pdopt')\n    # \u52a0\u8f7d\u53c2\u6570\u5230\u6a21\u578b\n    model.set_state_dict(params_dict)\n    \n    EPOCH_NUM = 5\n    BATCH_SIZE = 100\n    # \u5b9a\u4e49\u5b66\u4e60\u7387\uff0c\u5e76\u52a0\u8f7d\u4f18\u5316\u5668\u53c2\u6570\u5230\u6a21\u578b\u4e2d\n    total_steps = (int(50000\/\/BATCH_SIZE) + 1) * EPOCH_NUM\n    lr = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=total_steps, end_lr=0.001)\n    # \u4f7f\u7528Adam\u4f18\u5316\u5668\n    opt = paddle.optimizer.Adam(learning_rate=lr, parameters=model.parameters())\n    # \u52a0\u8f7d\u53c2\u6570\u5230\u4f18\u5316\u5668\n    opt.set_state_dict(opt_dict)\n\n    for epoch_id in range(1, EPOCH_NUM):\n        for batch_id, data in enumerate(train_loader()):\n            
#\u51c6\u5907\u6570\u636e\uff0c\u53d8\u5f97\u66f4\u52a0\u7b80\u6d01\n            image_data = data&#91;0].reshape(&#91;BATCH_SIZE, 1, 28, 28])\n            label_data = data&#91;1].reshape(&#91;BATCH_SIZE, 1])\n            image = paddle.to_tensor(image_data)\n            label = paddle.to_tensor(label_data)\n            \n            #\u524d\u5411\u8ba1\u7b97\u7684\u8fc7\u7a0b\n            predict, acc = model(image, label)\n\n            avg_acc = paddle.mean(acc)\n            #\u8ba1\u7b97\u635f\u5931\uff0c\u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\uff0c\u53d6\u4e00\u4e2a\u6279\u6b21\u6837\u672c\u635f\u5931\u7684\u5e73\u5747\u503c\n            loss = F.cross_entropy(predict, label)\n            avg_loss = paddle.mean(loss)\n            \n            #\u6bcf\u8bad\u7ec3\u4e86200\u6279\u6b21\u7684\u6570\u636e\uff0c\u6253\u5370\u4e0b\u5f53\u524dLoss\u7684\u60c5\u51b5\n            if batch_id % 200 == 0:\n                print(\"epoch: {}, batch: {}, loss is: {}, acc is {}\".format(epoch_id, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n            \n            #\u540e\u5411\u4f20\u64ad\uff0c\u66f4\u65b0\u53c2\u6570\u7684\u8fc7\u7a0b\n            # print(opt.state_dict())\n            avg_loss.backward()\n            opt.step()\n            opt.clear_grad()\n\nmodel = MNIST()\ntrain_again(model)<\/code><\/pre>\n\n\n\n<p>\u8fd9\u4e00\u7ae0\u57fa\u672c\u6ca1\u8bb2\u5565\u4e1c\u897f\uff0c\u5c31\u662fsave load\u7684\u4f7f\u7528<\/p>\n\n\n\n<p>\u6559\u7a0b\u4e0b\u4e00\u7ae0\u662f\u52a8\u9759\u8f6c\u6362\uff0c\u4e5f\u662fpaddle\u5f88\u795e\u5947\u7684\u529f\u80fd\uff0cpytorch\u6211\u5012\u662f\u6ca1\u6709\u627e\u5230\u80fd\u52a8\u8f6c\u9759\u7684\u8d44\u6599\u3002<\/p>\n\n\n\n<p><a href=\"https:\/\/aistudio.baidu.com\/aistudio\/projectdetail\/4405843\" target=\"_blank\"  rel=\"nofollow\" >\u3010\u624b\u5199\u6570\u5b57\u8bc6\u522b\u3011\u4e4b\u52a8\u8f6c\u9759\u90e8\u7f72<\/a><\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><img loading=\"lazy\" 
decoding=\"async\" width=\"310\" height=\"310\" src=\"http:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/112.png\" alt=\"\" class=\"wp-image-361\" srcset=\"https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/112.png 310w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/112-300x300.png 300w, https:\/\/www.gislxz.com\/wp-content\/uploads\/2022\/07\/112-150x150.png 150w\" sizes=\"auto, (max-width: 310px) 100vw, 310px\" \/><\/figure>\n","protected":false},"excerpt":{"rendered":"<p>\u6a21\u578b\u9636\u6bb5\u6027\u4fdd\u5b58\u548c\u52a0\u8f7d\uff0c\u52a8\u8f6c\u9759<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"_jetpack_memberships_contains_paid_content":false,"footnotes":""},"categories":[21],"tags":[],"class_list":["post-396","post","type-post","status-publish","format-standard","hentry","category-21"],"jetpack_featured_media_url":"","jetpack_sharing_enabled":true,"_links":{"self":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/396","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/comments?post=396"}],"version-history":[{"count":2,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/396\/revisions"}],"predecessor-version":[{"id":398,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/posts\/396\/revisions\/398"}],"wp:attachment":[{"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/media?parent=396"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/categories?post=396"},
{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.gislxz.com\/index.php\/wp-json\/wp\/v2\/tags?post=396"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}