{"id":254261,"date":"2022-10-12T10:09:57","date_gmt":"2022-10-12T02:09:57","guid":{"rendered":"https:\/\/lrxjmw.cn\/?p=254261"},"modified":"2022-10-04T11:12:18","modified_gmt":"2022-10-04T03:12:18","slug":"multi-task-learning-model-with-pytorch","status":"publish","type":"post","link":"https:\/\/lrxjmw.cn\/multi-task-learning-model-with-pytorch.html","title":{"rendered":"\u591a\u4efb\u52a1\u5b66\u4e60\u6a21\u578b\u7528Pytorch\u521b\u5efa\u8bd5\u4e00\u4e0b"},"content":{"rendered":"

Perhaps the most famous example of MTL is Tesla's Autopilot system. Autonomous driving has to handle a large number of tasks at once, such as object detection, depth estimation, 3D reconstruction, video analysis, and tracking. You might think this requires ten or more separate deep learning models, but that is not the case.

\"\"<\/a><\/p>\n

Top view of the hands of a woman tapping the keyboard of a laptop on a wooden desk with smart phone leaning near - technology, multitasking, business concept<\/p>\n<\/div>\n

**Introducing HydraNet**

Generally speaking, the architecture of a multi-task learning model is very simple: a backbone network acts as the feature extractor, and multiple heads are created on top of it, one for each task. A single model thus solves several tasks at once.

The feature-extraction backbone produces the image features, and its output is then split into multiple heads, each responsible for one specific task. Since the heads are independent of each other, they can be fine-tuned separately!

A Tesla talk explains this model in detail (YouTube: v=3SypMvnQT_s).

**A Multi-Task Learning Project**

In this article we will show how to implement a simpler HydraNet in PyTorch, using the UTKFace dataset, a classification dataset with three labels: gender, race, and age.

Our HydraNet will have three independent heads, and they are all different: predicting age is a regression task, predicting race is a multi-class classification problem, and predicting gender is a binary classification task.

Every PyTorch deep learning project should start by defining a Dataset and a DataLoader.

In this dataset, the labels are encoded in each image's filename, for example UTKFace/30_0_3_20170117145159065.jpg.chip.jpg:

- 30 is the age
- 0 is the gender (0: male, 1: female)
- 3 is the race (0: White, 1: Black, 2: Asian, 3: Indian, 4: Others)

So our custom Dataset can be written as follows:

```python
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class UTKFace(Dataset):
    def __init__(self, image_paths):
        # Resize, convert to tensor, and normalize with ImageNet statistics
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])
        self.image_paths = image_paths
        self.images = []
        self.ages = []
        self.genders = []
        self.races = []

        for path in image_paths:
            # Strip the "UTKFace/" prefix (8 characters), then split the
            # "age_gender_race_timestamp" filename on underscores
            filename = path[8:].split("_")

            if len(filename) == 4:
                self.images.append(path)
                self.ages.append(int(filename[0]))
                self.genders.append(int(filename[1]))
                self.races.append(int(filename[2]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        img = self.transform(img)

        age = self.ages[index]
        gender = self.genders[index]
        eth = self.races[index]

        return {'image': img, 'age': age, 'gender': gender, 'ethnicity': eth}
```

A quick walkthrough:

The __init__ method initializes our custom dataset: it sets up the transforms and extracts the labels from the image paths.

The __getitem__ method loads one image, applies the necessary transforms, fetches the labels, and returns a single element of the dataset, i.e. one sample.

Then we define the DataLoaders:

```python
from torch.utils.data import DataLoader

# train_dataset and valid_dataset are lists of image paths
train_dataloader = DataLoader(UTKFace(train_dataset), shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(UTKFace(valid_dataset), shuffle=False, batch_size=BATCH_SIZE)
```
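The snippet above assumes train_dataset and valid_dataset already hold lists of image paths. A minimal sketch of how they might be built (the UTKFace/ directory name, the 80/20 split, and the BATCH_SIZE value are assumptions, not from the original article):

```python
import glob
import random

BATCH_SIZE = 64  # assumed value

image_paths = glob.glob("UTKFace/*.jpg")  # assumes images sit under UTKFace/
random.seed(0)
random.shuffle(image_paths)

split = int(0.8 * len(image_paths))  # assumed 80/20 train/validation split
train_dataset = image_paths[:split]
valid_dataset = image_paths[split:]
```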

Next we define the model. We use a pretrained model as the backbone and then create three heads, one each for age, gender, and race.

```python
from collections import OrderedDict

import torch.nn as nn
from torchvision import models


class HydraNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Pretrained ResNet-18 backbone; replace its classifier with an
        # identity so self.net(x) returns the raw feature vector
        self.net = models.resnet18(pretrained=True)
        self.n_features = self.net.fc.in_features
        self.net.fc = nn.Identity()

        # Age head: regression, one output
        self.fc1 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(self.n_features, self.n_features)),
            ('relu1', nn.ReLU()),
            ('final', nn.Linear(self.n_features, 1)),
        ]))

        # Gender head: binary classification, one logit
        self.fc2 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(self.n_features, self.n_features)),
            ('relu1', nn.ReLU()),
            ('final', nn.Linear(self.n_features, 1)),
        ]))

        # Ethnicity head: multi-class classification, five logits
        self.fc3 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(self.n_features, self.n_features)),
            ('relu1', nn.ReLU()),
            ('final', nn.Linear(self.n_features, 5)),
        ]))

    def forward(self, x):
        # Run the backbone once and share its features across all heads
        features = self.net(x)
        age_head = self.fc1(features)
        gender_head = self.fc2(features)
        ethnicity_head = self.fc3(features)
        return age_head, gender_head, ethnicity_head
```

The forward method runs the backbone once and returns the output of each head.
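A quick sanity check of the output shapes (a sketch; the random batch is just for illustration and matches the 32x32 Resize used in the Dataset):

```python
import torch

model = HydraNet()
dummy = torch.randn(4, 3, 32, 32)  # batch of 4 fake RGB images
age, gender, eth = model(dummy)
print(age.shape, gender.shape, eth.shape)
# torch.Size([4, 1]) torch.Size([4, 1]) torch.Size([4, 5])
```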

The loss is crucial, since as the basis of optimization it directly affects the model's performance. The simplest thing we can think of is to just add the losses together:

```
L = L1 + L2 + L3
```

But in our model:

- L1: the age-related loss, e.g. mean absolute error, since age prediction is a regression task.
- L2: the race-related loss, cross-entropy, a multi-class classification loss.
- L3: the gender-related loss, e.g. binary cross-entropy.

The biggest problem with computing the loss this way is that the individual losses have different magnitudes, and the tasks deserve different weights. This is a question under active research; we won't discuss it here and will simply add the losses together, so our hyperparameters are as follows:

```python
model = HydraNet().to(device=device)

ethnicity_loss = nn.CrossEntropyLoss()
gender_loss = nn.BCELoss()  # nn.BCEWithLogitsLoss would fold in the sigmoid below
age_loss = nn.L1Loss()      # mean absolute error
sig = nn.Sigmoid()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.09)
```

The training loop then looks like this:

```python
from tqdm import tqdm

# assumes n_epochs and device are defined,
# e.g. device = "cuda" if torch.cuda.is_available() else "cpu"
for epoch in range(n_epochs):
    model.train()
    total_training_loss = 0

    for i, data in enumerate(tqdm(train_dataloader)):
        inputs = data["image"].to(device=device)

        age_label = data["age"].to(device=device)
        gender_label = data["gender"].to(device=device)
        eth_label = data["ethnicity"].to(device=device)

        optimizer.zero_grad()
        age_output, gender_output, eth_output = model(inputs)

        # One loss per head; labels are reshaped to match the outputs
        loss_1 = ethnicity_loss(eth_output, eth_label)
        loss_2 = gender_loss(sig(gender_output), gender_label.unsqueeze(1).float())
        loss_3 = age_loss(age_output, age_label.unsqueeze(1).float())

        # Naive sum of the three task losses
        loss = loss_1 + loss_2 + loss_3
        loss.backward()
        optimizer.step()

        # .item() detaches the value so the graph is not kept alive
        total_training_loss += loss.item()
```
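The val_dataloader defined earlier is never used above; a matching evaluation pass might look like this (a sketch, following the same label handling as the training loop):

```python
import torch

model.eval()
total_validation_loss = 0

with torch.no_grad():  # no gradients needed for evaluation
    for data in val_dataloader:
        inputs = data["image"].to(device=device)
        age_label = data["age"].to(device=device)
        gender_label = data["gender"].to(device=device)
        eth_label = data["ethnicity"].to(device=device)

        age_output, gender_output, eth_output = model(inputs)
        loss = (ethnicity_loss(eth_output, eth_label)
                + gender_loss(sig(gender_output), gender_label.unsqueeze(1).float())
                + age_loss(age_output, age_label.unsqueeze(1).float()))
        total_validation_loss += loss.item()

print(f"validation loss: {total_validation_loss / len(val_dataloader):.4f}")
```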

And with that, our most basic multi-task learning pipeline is complete.

**Optimizing the loss**

A multi-task loss function assigns a weight to each task's loss. In doing so, we must make sure all tasks are treated as equally important and that easy tasks do not dominate training. Setting the weights by hand is inefficient and suboptimal, so learning them automatically is essential.

Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (CVPR 2018)

This paper proposes pulling the different losses onto a common scale so they can be combined consistently. Concretely, it uses homoscedastic (task-dependent) uncertainty: the uncertainty is modeled as noise, and the per-task weights are learned during training.

End-to-End Multi-Task Learning with Attention (CVPR 2019)

This paper proposes Dynamic Weight Average (DWA), a mechanism that adjusts the weights automatically to produce a more reasonable allocation. Roughly speaking, for each task one first computes the ratio between its losses in the two previous epochs, divides it by a fixed temperature T, and then applies an exponential mapping and normalizes, giving each loss its share.
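A sketch of the DWA weight update (the function name and T=2.0 are illustrative; the factor k keeps the weights summing to the number of tasks, as in the paper):

```python
import math

def dwa_weights(losses_prev_epoch, losses_two_epochs_ago, T=2.0):
    """Dynamic Weight Average (Liu et al., CVPR 2019), sketched.

    Takes the average per-task losses of the last two epochs and returns one
    weight per task; tasks whose loss is shrinking slowly get more weight.
    """
    ratios = [l1 / l2 for l1, l2 in zip(losses_prev_epoch, losses_two_epochs_ago)]
    exps = [math.exp(r / T) for r in ratios]
    k = len(ratios)
    return [k * e / sum(exps) for e in exps]

# e.g. weights = dwa_weights([0.9, 0.5, 0.3], [1.0, 0.8, 0.3])
#      loss = weights[0]*loss_1 + weights[1]*loss_2 + weights[2]*loss_3
```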

Finally, if you are interested in multi-task learning, a good place to start is this survey:

A Survey on Multi-Task Learning (arXiv:1707.08114)

It surveys MTL from the perspectives of algorithmic modeling, applications, and theoretical analysis, and is the best material for getting started.
