<i id='R0M5X'><tr id='R0M5X'><dt id='R0M5X'><q id='R0M5X'><span id='R0M5X'><b id='R0M5X'><form id='R0M5X'><ins id='R0M5X'></ins><ul id='R0M5X'></ul><sub id='R0M5X'></sub></form><legend id='R0M5X'></legend><bdo id='R0M5X'><pre id='R0M5X'><center id='R0M5X'></center></pre></bdo></b><th id='R0M5X'></th></span></q></dt></tr></i><div id='R0M5X'><tfoot id='R0M5X'></tfoot><dl id='R0M5X'><fieldset id='R0M5X'></fieldset></dl></div>
    • <bdo id='R0M5X'></bdo><ul id='R0M5X'></ul>
    <tfoot id='R0M5X'></tfoot>
  • <legend id='R0M5X'><style id='R0M5X'><dir id='R0M5X'><q id='R0M5X'></q></dir></style></legend>

  • <small id='R0M5X'></small><noframes id='R0M5X'>

        如何保存 GridSearchCV 对象?

        时间:2023-11-08
            <tbody id='OmxYg'></tbody>
            <bdo id='OmxYg'></bdo><ul id='OmxYg'></ul>
              1. <small id='OmxYg'></small><noframes id='OmxYg'>

                1. <legend id='OmxYg'><style id='OmxYg'><dir id='OmxYg'><q id='OmxYg'></q></dir></style></legend>

                  <tfoot id='OmxYg'></tfoot>

                  <i id='OmxYg'><tr id='OmxYg'><dt id='OmxYg'><q id='OmxYg'><span id='OmxYg'><b id='OmxYg'><form id='OmxYg'><ins id='OmxYg'></ins><ul id='OmxYg'></ul><sub id='OmxYg'></sub></form><legend id='OmxYg'></legend><bdo id='OmxYg'><pre id='OmxYg'><center id='OmxYg'></center></pre></bdo></b><th id='OmxYg'></th></span></q></dt></tr></i><div id='OmxYg'><tfoot id='OmxYg'></tfoot><dl id='OmxYg'><fieldset id='OmxYg'></fieldset></dl></div>

                  本文介绍了如何保存 GridSearchCV 对象?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着跟版网的小编来一起学习吧!

                  问题描述

                  最近,我一直致力于在带有 Tensorflow 后端的 Keras 中应用网格搜索交叉验证 (sklearn GridSearchCV) 进行超参数调整.我的模型调整好后我正在尝试保存 GridSearchCV 对象以供以后使用,但没有成功.

                  Lately, I have been working on applying grid search cross validation (sklearn GridSearchCV) for hyper-parameter tuning in Keras with Tensorflow backend. An soon as my model is tuned I am trying to save the GridSearchCV object for later use without success.

                  超参数调优如下:

                  x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size = 0.85, random_state = 4)
                  
                  history = History() 
                  kfold = 10
                  
                  
                  regressor = KerasRegressor(build_fn = create_keras_model, epochs = 100, batch_size=1000, verbose=1)
                  
                  neurons = np.arange(10,101,10) 
                  hidden_layers = [1,2]
                  optimizer = ['adam','sgd']
                  activation = ['relu'] 
                  dropout = [0.1] 
                  
                  parameters = dict(neurons = neurons,
                                    hidden_layers = hidden_layers,
                                    optimizer = optimizer,
                                    activation = activation,
                                    dropout = dropout)
                  
                  gs = GridSearchCV(estimator = regressor,
                                    param_grid = parameters,
                                    scoring='mean_squared_error',
                                    n_jobs = 1,
                                    cv = kfold,
                                    verbose = 3,
                                    return_train_score=True))
                  
                  grid_result = gs.fit(NN_input,
                                      NN_target,
                                      callbacks=[history],
                                      verbose=1,
                                      validation_data=(x_val, y_val))
                  

                  备注:create_keras_model 函数初始化并编译一个 Keras Sequential 模型.

                  Remark: create_keras_model function initializes and compiles a Keras Sequential model.

                  执行交叉验证后,我尝试使用以下代码保存网格搜索对象 (gs):

                  After the cross validation is performed I am trying to save the grid search object (gs) with the following code:

                  from sklearn.externals import joblib
                  
                  joblib.dump(gs, 'GS_obj.pkl')
                  

                  我得到的错误如下:

                  TypeError: can't pickle _thread.RLock objects
                  

                  能否请您告诉我此错误的原因可能是什么?

                  Could you please let me know what might be the reason for this error?

                  谢谢!

                  P.S.:joblib.dump 方法适用于保存使用的 GridSearchCV 对象用于训练来自 sklearn 的 MLPRegressors.

                  P.S.: joblib.dump method works well for saving GridSearchCV objects that are used for the training MLPRegressors from sklearn.

                  推荐答案

                  使用

                  直接导入joblib

                  而不是

                  从 sklearn.externals 导入作业库

                  保存对象或结果:

                  joblib.dump(gs, 'model_file_name.pkl')

                  并使用以下方法加载您的结果:

                  and load your results using:

                  joblib.load("model_file_name.pkl")

                  这是一个简单的工作示例:

                  Here is a simple working example:

                  
                  import joblib
                  
                  #save your model or results
                  joblib.dump(gs, 'model_file_name.pkl')
                  
                  #load your model for further usage
                  joblib.load("model_file_name.pkl")
                  
                  

                  这篇关于如何保存 GridSearchCV 对象?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持跟版网!

                  上一篇:使用新的“虚拟"保存基于类的视图表单集项目.柱子 下一篇:为什么“10000000000000000 在范围内(1000000000000001)"Python 3 这

                  相关文章

                    <bdo id='JZlBI'></bdo><ul id='JZlBI'></ul>

                    <i id='JZlBI'><tr id='JZlBI'><dt id='JZlBI'><q id='JZlBI'><span id='JZlBI'><b id='JZlBI'><form id='JZlBI'><ins id='JZlBI'></ins><ul id='JZlBI'></ul><sub id='JZlBI'></sub></form><legend id='JZlBI'></legend><bdo id='JZlBI'><pre id='JZlBI'><center id='JZlBI'></center></pre></bdo></b><th id='JZlBI'></th></span></q></dt></tr></i><div id='JZlBI'><tfoot id='JZlBI'></tfoot><dl id='JZlBI'><fieldset id='JZlBI'></fieldset></dl></div>

                      <small id='JZlBI'></small><noframes id='JZlBI'>

                      <tfoot id='JZlBI'></tfoot>

                    1. <legend id='JZlBI'><style id='JZlBI'><dir id='JZlBI'><q id='JZlBI'></q></dir></style></legend>