{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1d0c7a52",
   "metadata": {},
   "source": [
    "# Multi-horizon `rgnp` forecasting with a small MLP\n",
    "\n",
    "For each row of `Raotbl6.csv` the model sees that row's variables and predicts `rgnp` 1 to 12 periods ahead (targets `rgnp_0` ... `rgnp_11`). Rows are split chronologically: the first 80% for training, the remainder for testing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c9fe2845",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "DATA_PATH = 'Raotbl6.csv'\n",
    "N_HORIZONS = 12    # forecast rgnp 1..N_HORIZONS periods ahead\n",
    "TRAIN_FRAC = 0.8   # fraction of rows (in time order) used for training\n",
    "SEED = 42\n",
    "\n",
    "# Seed every RNG that affects training (weight init, DataLoader shuffling).\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f3b9e10",
   "metadata": {},
   "source": [
    "## Data: multi-horizon targets and chronological split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a8d2c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(DATA_PATH, parse_dates=['date'], index_col='date')\n",
    "\n",
    "# rgnp_i holds rgnp observed i + 1 periods after the current row.\n",
    "for i in range(N_HORIZONS):\n",
    "    df[f'rgnp_{i}'] = df['rgnp'].shift(-i - 1)\n",
    "\n",
    "# Drops the last N_HORIZONS rows (incomplete target vector) and any row with missing values.\n",
    "df = df.dropna()\n",
    "\n",
    "targets = [col for col in df.columns if col.startswith('rgnp_')]\n",
    "features = df.drop(columns=targets)\n",
    "\n",
    "# Chronological split (no shuffling): every test row comes after every training row.\n",
    "# NOTE(review): the targets of the last N_HORIZONS training rows reach into the\n",
    "# test period; leave a gap of N_HORIZONS rows between the splits if that leakage matters.\n",
    "split = int(len(df) * TRAIN_FRAC)\n",
    "X_train, X_test = features.iloc[:split], features.iloc[split:]\n",
    "y_train, y_test = df[targets].iloc[:split], df[targets].iloc[split:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c5f8b3e",
   "metadata": {},
   "source": [
    "## Dataset and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fa886b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Raotbl6Dataset(Dataset):\n",
    "    \"\"\"Map-style dataset of (feature row, target row) pairs.\n",
    "\n",
    "    Both frames are converted to float32 tensors once in ``__init__`` so that\n",
    "    ``__getitem__`` is a plain tensor index instead of a pandas ``.iloc``\n",
    "    lookup and tensor construction for every sample.\n",
    "\n",
    "    Args:\n",
    "        X_train: DataFrame of numeric features, one row per sample.\n",
    "        y_train: DataFrame of numeric targets, row-aligned with ``X_train``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two frames have different row counts.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, X_train, y_train):\n",
    "        if len(X_train) != len(y_train):\n",
    "            raise ValueError(\n",
    "                f'X_train has {len(X_train)} rows but y_train has {len(y_train)}'\n",
    "            )\n",
    "        self.X_train = X_train\n",
    "        self.y_train = y_train\n",
    "        self._X = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "        self._y = torch.tensor(y_train.to_numpy(), dtype=torch.float32)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if torch.is_tensor(idx):\n",
    "            idx = idx.tolist()\n",
    "        return self._X[idx], self._y[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "471b9907",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    \"\"\"Two-layer MLP: Linear -> ReLU -> Linear.\n",
    "\n",
    "    Maps inputs of shape (batch_size, in_feats) to (batch_size, out_feats).\n",
    "\n",
    "    Args:\n",
    "        in_feats: number of input features per row.\n",
    "        out_feats: number of outputs (here, one per forecast horizon).\n",
    "        hidden_units: width of the hidden layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_feats, out_feats=12, hidden_units=32):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(in_feats, hidden_units),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden_units, out_feats),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b62e41f8",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9e68051",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 500\n",
    "LEARNING_RATE = 2e-4\n",
    "BATCH_SIZE = 32\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "dataset = Raotbl6Dataset(X_train, y_train)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# Size the network from the data instead of hard-coding 8 inputs / 12 outputs.\n",
    "model = Model(in_feats=X_train.shape[1], out_feats=y_train.shape[1]).to(device)\n",
    "\n",
    "# NOTE(review): betas=(0.5, 0.999) is a common GAN setting; Adam's defaults\n",
    "# (0.9, 0.999) are the usual choice for regression. Left unchanged here.\n",
    "# TODO: inputs and targets are unscaled (the multi-million initial MSE suggests\n",
    "# large target magnitudes); standardizing both with training-set statistics and\n",
    "# inverting before evaluation would likely speed up convergence.\n",
    "opt = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a1bb5a6a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0/500] ERROR: 5665249.0000\n",
      "Epoch [1/500] ERROR: 6766489.5000\n",
      "Epoch [2/500] ERROR: 6046984.0000\n",
      "Epoch [3/500] ERROR: 6704573.0000\n",
      "Epoch [4/500] ERROR: 6238314.0000\n",
      "Epoch [5/500] ERROR: 6090670.5000\n",
      "Epoch [6/500] ERROR: 6996567.0000\n",
      "Epoch [7/500] ERROR: 5883912.0000\n",
      "Epoch [8/500] ERROR: 6255249.0000\n",
      "Epoch [9/500] ERROR: 6302207.0000\n",
      "Epoch [10/500] ERROR: 6677468.0000\n",
      "Epoch [11/500] ERROR: 5846011.5000\n",
      "Epoch [12/500] ERROR: 5714164.5000\n",
      "Epoch [13/500] ERROR: 6616586.5000\n",
      "Epoch [14/500] ERROR: 5457780.5000\n",
      "Epoch [15/500] ERROR: 6271841.0000\n",
      "Epoch [16/500] ERROR: 6062314.5000\n",
      "Epoch [17/500] ERROR: 5327139.0000\n",
      "Epoch [18/500] ERROR: 5645721.5000\n",
      "Epoch [19/500] ERROR: 5864180.5000\n",
      "Epoch [20/500] ERROR: 6147806.5000\n",
      "Epoch [21/500] ERROR: 5998652.5000\n",
      "Epoch [22/500] ERROR: 5952885.0000\n",
      "Epoch [23/500] ERROR: 6143458.5000\n",
      "Epoch [24/500] ERROR: 6348583.0000\n",
      "Epoch [25/500] ERROR: 5830891.5000\n",
      "Epoch [26/500] ERROR: 5480641.0000\n",
      "Epoch [27/500] ERROR:
4948758.5000\n", "Epoch [28/500] ERROR: 5262505.0000\n", "Epoch [29/500] ERROR: 5594688.0000\n", "Epoch [30/500] ERROR: 5336407.5000\n", "Epoch [31/500] ERROR: 5717512.0000\n", "Epoch [32/500] ERROR: 5525984.0000\n", "Epoch [33/500] ERROR: 5889963.0000\n", "Epoch [34/500] ERROR: 5887103.0000\n", "Epoch [35/500] ERROR: 5426688.0000\n", "Epoch [36/500] ERROR: 5124869.5000\n", "Epoch [37/500] ERROR: 5618259.5000\n", "Epoch [38/500] ERROR: 5319918.0000\n", "Epoch [39/500] ERROR: 5731947.5000\n", "Epoch [40/500] ERROR: 5002415.0000\n", "Epoch [41/500] ERROR: 5978730.0000\n", "Epoch [42/500] ERROR: 4727743.0000\n", "Epoch [43/500] ERROR: 4799430.5000\n", "Epoch [44/500] ERROR: 4954430.5000\n", "Epoch [45/500] ERROR: 5633393.0000\n", "Epoch [46/500] ERROR: 5468525.0000\n", "Epoch [47/500] ERROR: 5349197.5000\n", "Epoch [48/500] ERROR: 5891797.5000\n", "Epoch [49/500] ERROR: 4614338.0000\n", "Epoch [50/500] ERROR: 5686690.5000\n", "Epoch [51/500] ERROR: 5210380.5000\n", "Epoch [52/500] ERROR: 5519909.5000\n", "Epoch [53/500] ERROR: 5204486.5000\n", "Epoch [54/500] ERROR: 4688049.0000\n", "Epoch [55/500] ERROR: 5066207.0000\n", "Epoch [56/500] ERROR: 4461897.0000\n", "Epoch [57/500] ERROR: 5060629.5000\n", "Epoch [58/500] ERROR: 5011283.5000\n", "Epoch [59/500] ERROR: 5258343.0000\n", "Epoch [60/500] ERROR: 4620037.5000\n", "Epoch [61/500] ERROR: 4955219.5000\n", "Epoch [62/500] ERROR: 4547958.5000\n", "Epoch [63/500] ERROR: 5162017.0000\n", "Epoch [64/500] ERROR: 4677858.0000\n", "Epoch [65/500] ERROR: 4374902.5000\n", "Epoch [66/500] ERROR: 4801792.5000\n", "Epoch [67/500] ERROR: 5135163.0000\n", "Epoch [68/500] ERROR: 4239194.5000\n", "Epoch [69/500] ERROR: 5207520.0000\n", "Epoch [70/500] ERROR: 4862048.5000\n", "Epoch [71/500] ERROR: 4522576.0000\n", "Epoch [72/500] ERROR: 4495291.0000\n", "Epoch [73/500] ERROR: 4647450.5000\n", "Epoch [74/500] ERROR: 4814331.5000\n", "Epoch [75/500] ERROR: 4081368.5000\n", "Epoch [76/500] ERROR: 4239002.5000\n", "Epoch [77/500] ERROR: 
4198702.5000\n", "Epoch [78/500] ERROR: 4114028.5000\n", "Epoch [79/500] ERROR: 4368050.5000\n", "Epoch [80/500] ERROR: 3847192.0000\n", "Epoch [81/500] ERROR: 3622285.2500\n", "Epoch [82/500] ERROR: 4405690.0000\n", "Epoch [83/500] ERROR: 3988158.7500\n", "Epoch [84/500] ERROR: 4139869.0000\n", "Epoch [85/500] ERROR: 4336219.5000\n", "Epoch [86/500] ERROR: 3795467.5000\n", "Epoch [87/500] ERROR: 4052139.2500\n", "Epoch [88/500] ERROR: 3817265.7500\n", "Epoch [89/500] ERROR: 4003637.2500\n", "Epoch [90/500] ERROR: 4146160.5000\n", "Epoch [91/500] ERROR: 4133907.5000\n", "Epoch [92/500] ERROR: 3885572.0000\n", "Epoch [93/500] ERROR: 4407370.5000\n", "Epoch [94/500] ERROR: 3898665.7500\n", "Epoch [95/500] ERROR: 4103945.2500\n", "Epoch [96/500] ERROR: 4000358.7500\n", "Epoch [97/500] ERROR: 3580763.0000\n", "Epoch [98/500] ERROR: 3563694.7500\n", "Epoch [99/500] ERROR: 3692401.2500\n", "Epoch [100/500] ERROR: 3897245.7500\n", "Epoch [101/500] ERROR: 3417812.0000\n", "Epoch [102/500] ERROR: 3834390.2500\n", "Epoch [103/500] ERROR: 3305647.5000\n", "Epoch [104/500] ERROR: 3647353.2500\n", "Epoch [105/500] ERROR: 3403613.2500\n", "Epoch [106/500] ERROR: 3477037.0000\n", "Epoch [107/500] ERROR: 3497095.2500\n", "Epoch [108/500] ERROR: 3038645.0000\n", "Epoch [109/500] ERROR: 3668238.0000\n", "Epoch [110/500] ERROR: 3467860.5000\n", "Epoch [111/500] ERROR: 3298832.5000\n", "Epoch [112/500] ERROR: 3294130.5000\n", "Epoch [113/500] ERROR: 2911191.5000\n", "Epoch [114/500] ERROR: 3373069.2500\n", "Epoch [115/500] ERROR: 2917577.7500\n", "Epoch [116/500] ERROR: 3115044.0000\n", "Epoch [117/500] ERROR: 2840279.2500\n", "Epoch [118/500] ERROR: 2990629.2500\n", "Epoch [119/500] ERROR: 3185098.7500\n", "Epoch [120/500] ERROR: 2996723.5000\n", "Epoch [121/500] ERROR: 2907819.2500\n", "Epoch [122/500] ERROR: 2972324.5000\n", "Epoch [123/500] ERROR: 2628630.0000\n", "Epoch [124/500] ERROR: 2751936.7500\n", "Epoch [125/500] ERROR: 2713021.0000\n", "Epoch [126/500] ERROR: 
2831771.5000\n", "Epoch [127/500] ERROR: 2816894.5000\n", "Epoch [128/500] ERROR: 2680418.2500\n", "Epoch [129/500] ERROR: 2878220.2500\n", "Epoch [130/500] ERROR: 2682823.2500\n", "Epoch [131/500] ERROR: 2660891.5000\n", "Epoch [132/500] ERROR: 2589765.5000\n", "Epoch [133/500] ERROR: 2566530.5000\n", "Epoch [134/500] ERROR: 2926894.0000\n", "Epoch [135/500] ERROR: 2413279.5000\n", "Epoch [136/500] ERROR: 2634831.5000\n", "Epoch [137/500] ERROR: 2645443.5000\n", "Epoch [138/500] ERROR: 2274955.2500\n", "Epoch [139/500] ERROR: 2585796.0000\n", "Epoch [140/500] ERROR: 2434393.7500\n", "Epoch [141/500] ERROR: 2362910.7500\n", "Epoch [142/500] ERROR: 2444359.2500\n", "Epoch [143/500] ERROR: 2472293.2500\n", "Epoch [144/500] ERROR: 2397576.2500\n", "Epoch [145/500] ERROR: 2062987.6250\n", "Epoch [146/500] ERROR: 2226736.0000\n", "Epoch [147/500] ERROR: 1983626.2500\n", "Epoch [148/500] ERROR: 2130679.5000\n", "Epoch [149/500] ERROR: 1985281.7500\n", "Epoch [150/500] ERROR: 2075771.6250\n", "Epoch [151/500] ERROR: 2090062.2500\n", "Epoch [152/500] ERROR: 2065924.8750\n", "Epoch [153/500] ERROR: 2038194.0000\n", "Epoch [154/500] ERROR: 2042899.7500\n", "Epoch [155/500] ERROR: 1768567.3750\n", "Epoch [156/500] ERROR: 1832315.3750\n", "Epoch [157/500] ERROR: 1869377.7500\n", "Epoch [158/500] ERROR: 1839572.1250\n", "Epoch [159/500] ERROR: 1801885.8750\n", "Epoch [160/500] ERROR: 1696655.8750\n", "Epoch [161/500] ERROR: 1791596.6250\n", "Epoch [162/500] ERROR: 1606142.5000\n", "Epoch [163/500] ERROR: 1782336.6250\n", "Epoch [164/500] ERROR: 1660405.2500\n", "Epoch [165/500] ERROR: 1557330.7500\n", "Epoch [166/500] ERROR: 1644607.3750\n", "Epoch [167/500] ERROR: 1480301.1250\n", "Epoch [168/500] ERROR: 1579116.1250\n", "Epoch [169/500] ERROR: 1593273.6250\n", "Epoch [170/500] ERROR: 1479812.5000\n", "Epoch [171/500] ERROR: 1496652.8750\n", "Epoch [172/500] ERROR: 1408004.1250\n", "Epoch [173/500] ERROR: 1262097.2500\n", "Epoch [174/500] ERROR: 1535174.0000\n", "Epoch 
[175/500] ERROR: 1278533.6250\n", "Epoch [176/500] ERROR: 1490654.6250\n", "Epoch [177/500] ERROR: 1283211.3750\n", "Epoch [178/500] ERROR: 1330647.7500\n", "Epoch [179/500] ERROR: 1243505.1250\n", "Epoch [180/500] ERROR: 1270135.7500\n", "Epoch [181/500] ERROR: 1208080.6250\n", "Epoch [182/500] ERROR: 1205707.5000\n", "Epoch [183/500] ERROR: 1099897.5000\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch [184/500] ERROR: 1145074.3750\n", "Epoch [185/500] ERROR: 1200663.1250\n", "Epoch [186/500] ERROR: 1027273.3125\n", "Epoch [187/500] ERROR: 1017164.0000\n", "Epoch [188/500] ERROR: 1138254.0000\n", "Epoch [189/500] ERROR: 1067482.8750\n", "Epoch [190/500] ERROR: 1040549.4375\n", "Epoch [191/500] ERROR: 981044.6875\n", "Epoch [192/500] ERROR: 954344.1250\n", "Epoch [193/500] ERROR: 974051.2500\n", "Epoch [194/500] ERROR: 934773.6875\n", "Epoch [195/500] ERROR: 896630.0000\n", "Epoch [196/500] ERROR: 864238.3125\n", "Epoch [197/500] ERROR: 842936.0000\n", "Epoch [198/500] ERROR: 793079.0625\n", "Epoch [199/500] ERROR: 906532.8125\n", "Epoch [200/500] ERROR: 803527.5625\n", "Epoch [201/500] ERROR: 715678.0000\n", "Epoch [202/500] ERROR: 780468.0625\n", "Epoch [203/500] ERROR: 760718.8750\n", "Epoch [204/500] ERROR: 705542.8750\n", "Epoch [205/500] ERROR: 770902.8125\n", "Epoch [206/500] ERROR: 707690.1875\n", "Epoch [207/500] ERROR: 698636.5625\n", "Epoch [208/500] ERROR: 771246.8750\n", "Epoch [209/500] ERROR: 636901.6875\n", "Epoch [210/500] ERROR: 681576.4375\n", "Epoch [211/500] ERROR: 692021.8125\n", "Epoch [212/500] ERROR: 582810.9375\n", "Epoch [213/500] ERROR: 600568.8750\n", "Epoch [214/500] ERROR: 592542.0000\n", "Epoch [215/500] ERROR: 623628.6875\n", "Epoch [216/500] ERROR: 597973.8125\n", "Epoch [217/500] ERROR: 555504.6875\n", "Epoch [218/500] ERROR: 533505.6250\n", "Epoch [219/500] ERROR: 572521.6250\n", "Epoch [220/500] ERROR: 488955.2188\n", "Epoch [221/500] ERROR: 526641.1250\n", "Epoch [222/500] ERROR: 484473.1250\n", "Epoch 
[223/500] ERROR: 501777.6562\n", "Epoch [224/500] ERROR: 458777.6562\n", "Epoch [225/500] ERROR: 422624.5938\n", "Epoch [226/500] ERROR: 408265.2188\n", "Epoch [227/500] ERROR: 428082.8750\n", "Epoch [228/500] ERROR: 426105.8438\n", "Epoch [229/500] ERROR: 439587.5938\n", "Epoch [230/500] ERROR: 426362.3438\n", "Epoch [231/500] ERROR: 408236.9062\n", "Epoch [232/500] ERROR: 381982.1562\n", "Epoch [233/500] ERROR: 355936.9688\n", "Epoch [234/500] ERROR: 386530.4375\n", "Epoch [235/500] ERROR: 329656.5000\n", "Epoch [236/500] ERROR: 342706.6250\n", "Epoch [237/500] ERROR: 339265.7812\n", "Epoch [238/500] ERROR: 360974.4062\n", "Epoch [239/500] ERROR: 316151.9375\n", "Epoch [240/500] ERROR: 310152.2500\n", "Epoch [241/500] ERROR: 320393.9375\n", "Epoch [242/500] ERROR: 311855.7500\n", "Epoch [243/500] ERROR: 285503.1562\n", "Epoch [244/500] ERROR: 290572.7812\n", "Epoch [245/500] ERROR: 282856.0625\n", "Epoch [246/500] ERROR: 260234.7031\n", "Epoch [247/500] ERROR: 260931.2812\n", "Epoch [248/500] ERROR: 244362.5000\n", "Epoch [249/500] ERROR: 264910.5625\n", "Epoch [250/500] ERROR: 234974.1094\n", "Epoch [251/500] ERROR: 245112.7188\n", "Epoch [252/500] ERROR: 223209.3125\n", "Epoch [253/500] ERROR: 228371.0000\n", "Epoch [254/500] ERROR: 233212.4531\n", "Epoch [255/500] ERROR: 211591.7812\n", "Epoch [256/500] ERROR: 206578.7188\n", "Epoch [257/500] ERROR: 206144.0781\n", "Epoch [258/500] ERROR: 185789.3594\n", "Epoch [259/500] ERROR: 208333.9219\n", "Epoch [260/500] ERROR: 179738.6719\n", "Epoch [261/500] ERROR: 187828.3281\n", "Epoch [262/500] ERROR: 179257.3750\n", "Epoch [263/500] ERROR: 180792.1562\n", "Epoch [264/500] ERROR: 159553.0938\n", "Epoch [265/500] ERROR: 154561.6094\n", "Epoch [266/500] ERROR: 155626.1406\n", "Epoch [267/500] ERROR: 157717.1094\n", "Epoch [268/500] ERROR: 158556.1094\n", "Epoch [269/500] ERROR: 151389.3281\n", "Epoch [270/500] ERROR: 159563.8906\n", "Epoch [271/500] ERROR: 135567.4688\n", "Epoch [272/500] ERROR: 141261.4688\n", "Epoch 
[273/500] ERROR: 137841.2500\n", "Epoch [274/500] ERROR: 137810.7969\n", "Epoch [275/500] ERROR: 119899.3359\n", "Epoch [276/500] ERROR: 126941.0156\n", "Epoch [277/500] ERROR: 112686.9141\n", "Epoch [278/500] ERROR: 111290.4844\n", "Epoch [279/500] ERROR: 118464.0859\n", "Epoch [280/500] ERROR: 111350.2422\n", "Epoch [281/500] ERROR: 102941.2344\n", "Epoch [282/500] ERROR: 97419.1172\n", "Epoch [283/500] ERROR: 105180.4297\n", "Epoch [284/500] ERROR: 102467.0156\n", "Epoch [285/500] ERROR: 100131.0547\n", "Epoch [286/500] ERROR: 93374.6094\n", "Epoch [287/500] ERROR: 93758.2031\n", "Epoch [288/500] ERROR: 92716.0078\n", "Epoch [289/500] ERROR: 96731.4922\n", "Epoch [290/500] ERROR: 89522.8047\n", "Epoch [291/500] ERROR: 85883.3594\n", "Epoch [292/500] ERROR: 85192.2656\n", "Epoch [293/500] ERROR: 84716.6016\n", "Epoch [294/500] ERROR: 75366.1250\n", "Epoch [295/500] ERROR: 71328.9297\n", "Epoch [296/500] ERROR: 71688.8203\n", "Epoch [297/500] ERROR: 75521.8047\n", "Epoch [298/500] ERROR: 77471.4531\n", "Epoch [299/500] ERROR: 71983.1641\n", "Epoch [300/500] ERROR: 68530.8203\n", "Epoch [301/500] ERROR: 67069.9922\n", "Epoch [302/500] ERROR: 66144.4766\n", "Epoch [303/500] ERROR: 59470.3633\n", "Epoch [304/500] ERROR: 59644.5000\n", "Epoch [305/500] ERROR: 64742.1172\n", "Epoch [306/500] ERROR: 52621.8672\n", "Epoch [307/500] ERROR: 55709.2070\n", "Epoch [308/500] ERROR: 61200.7227\n", "Epoch [309/500] ERROR: 54309.5742\n", "Epoch [310/500] ERROR: 51374.5781\n", "Epoch [311/500] ERROR: 52113.6445\n", "Epoch [312/500] ERROR: 47893.7383\n", "Epoch [313/500] ERROR: 43669.3125\n", "Epoch [314/500] ERROR: 48984.2695\n", "Epoch [315/500] ERROR: 44717.7383\n", "Epoch [316/500] ERROR: 47513.7031\n", "Epoch [317/500] ERROR: 46554.0156\n", "Epoch [318/500] ERROR: 46030.2930\n", "Epoch [319/500] ERROR: 44909.3945\n", "Epoch [320/500] ERROR: 47998.7422\n", "Epoch [321/500] ERROR: 41827.2656\n", "Epoch [322/500] ERROR: 40405.9102\n", "Epoch [323/500] ERROR: 40138.8633\n", 
"Epoch [324/500] ERROR: 40198.7305\n", "Epoch [325/500] ERROR: 39299.2734\n", "Epoch [326/500] ERROR: 36208.8086\n", "Epoch [327/500] ERROR: 34550.8125\n", "Epoch [328/500] ERROR: 32144.3750\n", "Epoch [329/500] ERROR: 32913.3867\n", "Epoch [330/500] ERROR: 32688.7285\n", "Epoch [331/500] ERROR: 35045.1797\n", "Epoch [332/500] ERROR: 33623.0195\n", "Epoch [333/500] ERROR: 32091.8086\n", "Epoch [334/500] ERROR: 29882.1641\n", "Epoch [335/500] ERROR: 29420.9727\n", "Epoch [336/500] ERROR: 30910.3301\n", "Epoch [337/500] ERROR: 32168.6523\n", "Epoch [338/500] ERROR: 27605.6973\n", "Epoch [339/500] ERROR: 27904.7148\n", "Epoch [340/500] ERROR: 25386.0742\n", "Epoch [341/500] ERROR: 29568.9023\n", "Epoch [342/500] ERROR: 29629.2402\n", "Epoch [343/500] ERROR: 25195.4375\n", "Epoch [344/500] ERROR: 24174.0312\n", "Epoch [345/500] ERROR: 23399.0059\n", "Epoch [346/500] ERROR: 25463.6660\n", "Epoch [347/500] ERROR: 23639.7949\n", "Epoch [348/500] ERROR: 25230.7598\n", "Epoch [349/500] ERROR: 21317.8340\n", "Epoch [350/500] ERROR: 24193.3066\n", "Epoch [351/500] ERROR: 23996.4180\n", "Epoch [352/500] ERROR: 23553.6836\n", "Epoch [353/500] ERROR: 22049.5566\n", "Epoch [354/500] ERROR: 23816.1660\n", "Epoch [355/500] ERROR: 22430.0762\n", "Epoch [356/500] ERROR: 19893.5566\n", "Epoch [357/500] ERROR: 23059.1211\n", "Epoch [358/500] ERROR: 19673.8496\n", "Epoch [359/500] ERROR: 20735.6836\n", "Epoch [360/500] ERROR: 20687.4082\n", "Epoch [361/500] ERROR: 20723.2852\n", "Epoch [362/500] ERROR: 22920.3848\n", "Epoch [363/500] ERROR: 20310.4824\n", "Epoch [364/500] ERROR: 20269.6348\n", "Epoch [365/500] ERROR: 20029.4199\n", "Epoch [366/500] ERROR: 18959.4160\n", "Epoch [367/500] ERROR: 15988.7158\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch [368/500] ERROR: 19006.8633\n", "Epoch [369/500] ERROR: 18354.0625\n", "Epoch [370/500] ERROR: 15863.4668\n", "Epoch [371/500] ERROR: 15994.8125\n", "Epoch [372/500] ERROR: 18140.0410\n", "Epoch [373/500] ERROR: 
16175.2783\n", "Epoch [374/500] ERROR: 15778.8213\n", "Epoch [375/500] ERROR: 17274.9688\n", "Epoch [376/500] ERROR: 18792.8652\n", "Epoch [377/500] ERROR: 16362.4932\n", "Epoch [378/500] ERROR: 15920.2881\n", "Epoch [379/500] ERROR: 18069.3418\n", "Epoch [380/500] ERROR: 16469.9863\n", "Epoch [381/500] ERROR: 17217.3652\n", "Epoch [382/500] ERROR: 18992.1035\n", "Epoch [383/500] ERROR: 16638.7578\n", "Epoch [384/500] ERROR: 13901.7793\n", "Epoch [385/500] ERROR: 17286.6387\n", "Epoch [386/500] ERROR: 17399.2402\n", "Epoch [387/500] ERROR: 14857.7471\n", "Epoch [388/500] ERROR: 16751.0742\n", "Epoch [389/500] ERROR: 16263.5508\n", "Epoch [390/500] ERROR: 16098.8730\n", "Epoch [391/500] ERROR: 18780.7363\n", "Epoch [392/500] ERROR: 11355.0557\n", "Epoch [393/500] ERROR: 14476.4258\n", "Epoch [394/500] ERROR: 13582.9580\n", "Epoch [395/500] ERROR: 14740.5000\n", "Epoch [396/500] ERROR: 15038.8926\n", "Epoch [397/500] ERROR: 11870.3057\n", "Epoch [398/500] ERROR: 16520.7773\n", "Epoch [399/500] ERROR: 12342.9258\n", "Epoch [400/500] ERROR: 12574.6670\n", "Epoch [401/500] ERROR: 13184.3564\n", "Epoch [402/500] ERROR: 14563.4814\n", "Epoch [403/500] ERROR: 11666.1230\n", "Epoch [404/500] ERROR: 14378.5488\n", "Epoch [405/500] ERROR: 11100.2324\n", "Epoch [406/500] ERROR: 16029.1807\n", "Epoch [407/500] ERROR: 11934.1240\n", "Epoch [408/500] ERROR: 11719.6719\n", "Epoch [409/500] ERROR: 15555.9531\n", "Epoch [410/500] ERROR: 7600.9790\n", "Epoch [411/500] ERROR: 16091.4111\n", "Epoch [412/500] ERROR: 10949.3066\n", "Epoch [413/500] ERROR: 15417.2656\n", "Epoch [414/500] ERROR: 14068.3936\n", "Epoch [415/500] ERROR: 12099.4199\n", "Epoch [416/500] ERROR: 9536.8799\n", "Epoch [417/500] ERROR: 10579.9307\n", "Epoch [418/500] ERROR: 13088.9912\n", "Epoch [419/500] ERROR: 9846.8281\n", "Epoch [420/500] ERROR: 14895.1133\n", "Epoch [421/500] ERROR: 12326.7256\n", "Epoch [422/500] ERROR: 10120.9619\n", "Epoch [423/500] ERROR: 9871.9512\n", "Epoch [424/500] ERROR: 9565.8350\n", 
"Epoch [425/500] ERROR: 12096.2441\n", "Epoch [426/500] ERROR: 10889.2637\n", "Epoch [427/500] ERROR: 10154.0088\n", "Epoch [428/500] ERROR: 12802.4551\n", "Epoch [429/500] ERROR: 13417.0449\n", "Epoch [430/500] ERROR: 10324.2471\n", "Epoch [431/500] ERROR: 16851.1602\n", "Epoch [432/500] ERROR: 11297.0391\n", "Epoch [433/500] ERROR: 16179.0156\n", "Epoch [434/500] ERROR: 14947.5596\n", "Epoch [435/500] ERROR: 9119.6279\n", "Epoch [436/500] ERROR: 15094.3887\n", "Epoch [437/500] ERROR: 12299.0195\n", "Epoch [438/500] ERROR: 12579.1543\n", "Epoch [439/500] ERROR: 13010.0332\n", "Epoch [440/500] ERROR: 11936.0000\n", "Epoch [441/500] ERROR: 9911.3213\n", "Epoch [442/500] ERROR: 12044.9014\n", "Epoch [443/500] ERROR: 10288.1484\n", "Epoch [444/500] ERROR: 13719.3369\n", "Epoch [445/500] ERROR: 15717.6758\n", "Epoch [446/500] ERROR: 9881.3633\n", "Epoch [447/500] ERROR: 11272.0215\n", "Epoch [448/500] ERROR: 13930.7822\n", "Epoch [449/500] ERROR: 12136.0312\n", "Epoch [450/500] ERROR: 9280.6611\n", "Epoch [451/500] ERROR: 10968.5283\n", "Epoch [452/500] ERROR: 9680.3965\n", "Epoch [453/500] ERROR: 14612.3770\n", "Epoch [454/500] ERROR: 9187.3857\n", "Epoch [455/500] ERROR: 8361.5625\n", "Epoch [456/500] ERROR: 10765.2852\n", "Epoch [457/500] ERROR: 13021.6816\n", "Epoch [458/500] ERROR: 17085.4590\n", "Epoch [459/500] ERROR: 11360.2471\n", "Epoch [460/500] ERROR: 10376.6377\n", "Epoch [461/500] ERROR: 9064.3984\n", "Epoch [462/500] ERROR: 8998.8252\n", "Epoch [463/500] ERROR: 15337.5420\n", "Epoch [464/500] ERROR: 15993.6787\n", "Epoch [465/500] ERROR: 12342.5645\n", "Epoch [466/500] ERROR: 11282.2852\n", "Epoch [467/500] ERROR: 6347.4351\n", "Epoch [468/500] ERROR: 10288.6270\n", "Epoch [469/500] ERROR: 9583.7383\n", "Epoch [470/500] ERROR: 12725.2607\n", "Epoch [471/500] ERROR: 14775.3506\n", "Epoch [472/500] ERROR: 13469.9736\n", "Epoch [473/500] ERROR: 12204.7705\n", "Epoch [474/500] ERROR: 7863.4150\n", "Epoch [475/500] ERROR: 9599.0977\n", "Epoch [476/500] ERROR: 
9666.9912\n",
      "Epoch [477/500] ERROR: 7404.3384\n",
      "Epoch [478/500] ERROR: 13729.2002\n",
      "Epoch [479/500] ERROR: 10838.1768\n",
      "Epoch [480/500] ERROR: 8060.2915\n",
      "Epoch [481/500] ERROR: 14313.9912\n",
      "Epoch [482/500] ERROR: 10251.9531\n",
      "Epoch [483/500] ERROR: 9581.0205\n",
      "Epoch [484/500] ERROR: 13795.7041\n",
      "Epoch [485/500] ERROR: 15384.6182\n",
      "Epoch [486/500] ERROR: 14782.8408\n",
      "Epoch [487/500] ERROR: 6769.0591\n",
      "Epoch [488/500] ERROR: 12082.3740\n",
      "Epoch [489/500] ERROR: 7909.1895\n",
      "Epoch [490/500] ERROR: 16194.1133\n",
      "Epoch [491/500] ERROR: 10340.7002\n",
      "Epoch [492/500] ERROR: 10477.5283\n",
      "Epoch [493/500] ERROR: 10792.0986\n",
      "Epoch [494/500] ERROR: 10765.4189\n",
      "Epoch [495/500] ERROR: 14855.5664\n",
      "Epoch [496/500] ERROR: 10105.3105\n",
      "Epoch [497/500] ERROR: 13815.2207\n",
      "Epoch [498/500] ERROR: 7370.5957\n",
      "Epoch [499/500] ERROR: 10001.3135\n"
     ]
    }
   ],
   "source": [
    "LOG_EVERY = 50  # print the epoch's mean training loss every LOG_EVERY epochs\n",
    "\n",
    "for epoch in range(1, NUM_EPOCHS + 1):\n",
    "    model.train()\n",
    "    total_loss, n_samples = 0.0, 0\n",
    "    for X, y in dataloader:\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        batch_pred = model(X)\n",
    "        loss = criterion(batch_pred, y)  # nn.MSELoss expects (input, target)\n",
    "\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        # Weight by batch size: the last batch is usually smaller than the rest.\n",
    "        total_loss += loss.item() * len(X)\n",
    "        n_samples += len(X)\n",
    "\n",
    "    if epoch == 1 or epoch % LOG_EVERY == 0:\n",
    "        print(f'Epoch [{epoch}/{NUM_EPOCHS}] train MSE: {total_loss / n_samples:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4c8a3d9",
   "metadata": {},
   "source": [
    "## Evaluation on the held-out period"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d3e7af32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# eval(): inference mode for any train-only layers; no_grad(): the output carries\n",
    "# no autograd graph. .cpu().numpy() gives pandas a plain float array, also when\n",
    "# the model lives on CUDA.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    X_test_t = torch.tensor(X_test.to_numpy(), dtype=torch.float32, device=device)\n",
    "    pred_values = model(X_test_t).cpu().numpy()\n",
    "\n",
    "pred = pd.DataFrame(pred_values, columns=y_test.columns, index=y_test.index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e670c096",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set RMSE for each forecast horizon, in the units of rgnp.\n",
    "rmse_by_horizon = pd.Series(\n",
    "    np.sqrt(mean_squared_error(y_test, pred, multioutput='raw_values')),\n",
    "    index=y_test.columns,\n",
    "    name='rmse',\n",
    ")\n",
    "rmse_by_horizon"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p38",
   "language": "python",
   "name": "conda_pytorch_p38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }