sklearn与LightGBM配合使用

Florida ·
更新时间:2024-09-21
· 633 次阅读

LightGBM建模,sklearn评估 # coding: utf-8 import lightgbm as lgb import pandas as pd from sklearn.metrics import mean_squared_error from sklearn.model_selection import GridSearchCV # 加载数据 print('加载数据...') df_train = pd.read_csv('./data/regression.train.txt', header=None, sep='\t') df_test = pd.read_csv('./data/regression.test.txt', header=None, sep='\t') # 取出特征和标签 y_train = df_train[0].values y_test = df_test[0].values X_train = df_train.drop(0, axis=1).values X_test = df_test.drop(0, axis=1).values print('开始训练...') # 直接初始化LGBMRegressor # 这个LightGBM的Regressor和sklearn中其他Regressor基本是一致的 gbm = lgb.LGBMRegressor(objective='regression', num_leaves=31, learning_rate=0.05, n_estimators=20) # 使用fit函数拟合 gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], eval_metric='l1', early_stopping_rounds=5) # 预测 print('开始预测...') y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_) # 评估预测结果 print('预测结果的rmse是:') print(mean_squared_error(y_test, y_pred) ** 0.5) 网格搜索查找最优超参数 # 配合scikit-learn的网格搜索交叉验证选择最优超参数 estimator = lgb.LGBMRegressor(num_leaves=31) param_grid = { 'learning_rate': [0.01, 0.1, 1], 'n_estimators': [20, 40] } gbm = GridSearchCV(estimator, param_grid) gbm.fit(X_train, y_train) print('用网格搜索找到的最优超参数为:') print(gbm.best_params_) 绘图解释 # coding: utf-8 import lightgbm as lgb import pandas as pd try: import matplotlib.pyplot as plt except ImportError: raise ImportError('You need to install matplotlib for plotting.') # 加载数据集 print('加载数据...') df_train = pd.read_csv('./data/regression.train.txt', header=None, sep='\t') df_test = pd.read_csv('./data/regression.test.txt', header=None, sep='\t') # 取出特征和标签 y_train = df_train[0].values y_test = df_test[0].values X_train = df_train.drop(0, axis=1).values X_test = df_test.drop(0, axis=1).values # 构建lgb中的Dataset数据格式 lgb_train = lgb.Dataset(X_train, y_train) lgb_test = lgb.Dataset(X_test, y_test, reference=lgb_train) # 设定参数 params = { 'num_leaves': 5, 'metric': ('l1', 'l2'), 'verbose': 0 } evals_result = {} # to record eval results for plotting print('开始训练...') # 训练 gbm = lgb.train(params, lgb_train, num_boost_round=100, valid_sets=[lgb_train, lgb_test], feature_name=['f' + str(i + 1) for i in range(28)], categorical_feature=[21], evals_result=evals_result, verbose_eval=10) print('在训练过程中绘图...') ax = lgb.plot_metric(evals_result, metric='l1') plt.show() print('画出特征重要度...') ax = lgb.plot_importance(gbm, max_num_features=10) plt.show() print('画出第84颗树...') ax = lgb.plot_tree(gbm, tree_index=83, figsize=(20, 8), show_info=['split_gain']) plt.show() #print('用graphviz画出第84颗树...') #graph = lgb.create_tree_digraph(gbm, tree_index=83, name='Tree84') #graph.render(view=True)
作者:小菜鸡一号



lightgbm

需要 登录 后方可回复, 如果你还没有账号请 注册新账号