基于LSTM神经网络模型的北京PM2.5排放量预测

Velika ·
更新时间:2024-09-21
· 983 次阅读

# Predicting Beijing PM2.5 with an LSTM neural network -- code implementation.
#
# Pipeline: load the hourly PRSA data set, fill missing pm2.5 readings, build a
# datetime index, one-hot encode the wind direction, cut the series into
# sliding windows (120 hourly steps of history -> pm2.5 value 24 steps after
# the last input step), standardize, train a four-layer LSTM and print the
# MSE and R^2 obtained on the held-out windows.
import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.impute import SimpleImputer
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from tensorflow.keras import layers

data = pd.read_csv(r"E:\tensorflow2.0\tensorflow2.0_Code\datasets\PRSA_data_2010.1.1-2014.12.31.csv")

# Output of data.info() on the raw file:
#   RangeIndex: 43824 entries, 0 to 43823
#   Data columns (total 13 columns):
#   No       43824 non-null int64
#   year     43824 non-null int64
#   month    43824 non-null int64
#   day      43824 non-null int64
#   hour     43824 non-null int64
#   pm2.5    41757 non-null float64
#   DEWP     43824 non-null int64
#   TEMP     43824 non-null float64
#   PRES     43824 non-null float64
#   cbwd     43824 non-null object
#   Iws      43824 non-null float64
#   Is       43824 non-null int64
#   Ir       43824 non-null int64
# There are 43824 records in total; pm2.5 is the only column with missing values.
# print(data.info())

# ---------------------------------------------------------------- preprocessing
# Fill the missing pm2.5 readings with the column mean.
# NOTE(review): the mean is computed over the whole series, including the rows
# that later land in the test windows -- consider fitting on the training span.
imputer = SimpleImputer(missing_values=np.nan, strategy="mean")
# fit_transform returns an (n, 1) array; flatten it so it is a plain column.
data["pm2.5"] = imputer.fit_transform(data[["pm2.5"]]).ravel()
# print(data.info())

# After imputation the first rows look like this (98.613215 is the filled mean):
#    No  year  month  day  hour      pm2.5  ...  TEMP    PRES cbwd    Iws  Is  Ir
# 0   1  2010      1    1     0  98.613215  ... -11.0  1021.0   NW   1.79   0   0
# 1   2  2010      1    1     1  98.613215  ... -12.0  1020.0   NW   4.92   0   0
# 2   3  2010      1    1     2  98.613215  ... -11.0  1019.0   NW   6.71   0   0
# 3   4  2010      1    1     3  98.613215  ... -14.0  1019.0   NW   9.84   0   0
# 4   5  2010      1    1     4  98.613215  ... -12.0  1018.0   NW  12.97   0   0
# The first column ("No") is just a row counter and carries no information.
# print(data.head())

# Drop the "No" column.
data.drop(columns=["No"], inplace=True)
# print(data.info())

# Merge year / month / day / hour into a single timestamp. pd.to_datetime
# assembles the column vectorized instead of building one datetime per row.
data["time"] = pd.to_datetime(data[["year", "month", "day", "hour"]])
# Use the timestamp as the new index.
data.set_index("time", inplace=True)
# The separate date-part columns are no longer needed.
data.drop(columns=["year", "month", "day", "hour"], inplace=True)
#                          pm2.5  DEWP  TEMP    PRES cbwd   Iws  Is  Ir
# time
# 2010-01-01 00:00:00  98.613215   -21 -11.0  1021.0   NW  1.79   0   0
# print(data.head())

# One-hot encode the wind direction column "cbwd".
# Its distinct values are ['NW', 'cv', 'NE', 'SE'].
# dtype=float keeps the frame purely numeric: newer pandas versions return
# bool dummies by default, which would turn data.values into an object array.
data = data.join(pd.get_dummies(data.cbwd, dtype=float))
data.drop(columns=["cbwd"], inplace=True)
# print(data.head())
# (43824, 11)
# print(data.shape)

# -------------------------------------------------------------- sliding windows
sequence_length = 5 * 24  # number of hourly steps fed to the network
delay = 24                # the target lies this many steps after the last input

# Slice one NumPy array instead of the DataFrame: same windows, far cheaper.
values = data.values.astype(np.float64)
window_size = sequence_length + delay
data_ = np.array([values[i:i + window_size]
                  for i in range(len(values) - sequence_length - delay)])
# Each window has shape (144, 11).
# (43680, 144, 11)
# print(data_.shape)

X = data_[:, :-delay, :]  # first 120 steps of every window, all features
y = data_[:, -1, 0]       # pm2.5 (column 0) at the last step of the window
# (43680, 120, 11)
# (43680,)
# print(X.shape)
# print(y.shape)

# Consecutive windows overlap almost completely, so a shuffled split would put
# near-identical windows into both sets and overstate the test score. Split
# chronologically instead: the first 70% of windows train, the last 30% test.
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, shuffle=False)
# (30576, 120, 11)
# print(Xtrain.shape)

# Standardize with statistics taken from the training windows only.
mean = Xtrain.mean(axis=0)
std = Xtrain.std(axis=0)
std[std == 0] = 1.0  # a constant position would otherwise divide by zero
Xtrain = (Xtrain - mean) / std
Xtest = (Xtest - mean) / std
batch_size = 128
# (30576, 120, 11)
# (120, 11)
# (13104, 120, 11)
# (120, 11)
# (30576,)
# (13104,)
# print(Xtrain.shape)
# print(Xtrain.shape[1:])
# print(Xtest.shape)
# print(Xtest.shape[1:])
# print(Ytrain.shape)
# print(Ytest.shape)

# ------------------------------------------------------------------------ model
# Four stacked LSTM layers; only the last one collapses the sequence dimension.
model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(32, activation='tanh', input_shape=(Xtrain.shape[1:]), return_sequences=True))
model.add(tf.keras.layers.LSTM(32, return_sequences=True))
model.add(tf.keras.layers.LSTM(32, return_sequences=True))
model.add(tf.keras.layers.LSTM(32, return_sequences=False))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(1, activation='linear'))

# Compile and train the model.
model.compile(loss='mean_squared_error', optimizer='adam')
history = model.fit(Xtrain, Ytrain, batch_size=batch_size, epochs=200)
# dict_keys(['loss'])
# print(history.history.keys())

# ------------------------------------------------------------------- evaluation
Y_predict = model.predict(Xtest)
print(mean_squared_error(Ytest, Y_predict))
print(r2_score(Ytest, Y_predict))

# Original article by itboy996.
作者:itboy996



模型 网络模型 lstm pm2

需要 登录 后方可回复, 如果你还没有账号请 注册新账号
相关文章