Production ML for Practitioners: How to Accelerate Model Training with LightGBM & Optuna

Ignacio Amaya de la Peña | 8 min read | Nov 26, 2023

Refer to Google Colab for code snippets.

When it comes to the world of data science, machine learning models usually get all the attention, but the real heroes of data science lie elsewhere.

In most data science projects, the most time-consuming part is not modeling, the phase where we select the best models, run tests with different training time frames, try including or excluding features, and compare model hyperparameters.

When we dive into the well-established CRISP-DM (Cross-Industry Standard Process for Data Mining) framework, we uncover a truth often overshadowed: modeling is not the star of the show.

The secret sauce of a successful project is found in the earlier phases: business understanding, data understanding, and data preparation. These stages are often underestimated, but they are where the true magic begins to unfold. They might not be as glamorous, yet they are the foundation upon which successful data science projects are built, and they usually take much longer than modeling.

Evaluation and deployment are also important, but they usually come later in the project and have less uncertainty. If your problem is defined correctly and your data is in the right shape, then most likely the rest of the steps won’t be a deal-breaker.

[Figure: business understanding and data understanding]

Having said that, once you get to the modeling step the risk of spending excessive time on training various models and hunting for the best hyperparameters remains a concern. Luckily, there are some ways that can help you accelerate your training and search code.

Scikit-learn is often the starting point for building your initial baseline models. Beginning with something simple, like linear or logistic regression, is a sensible choice to establish a benchmark. However, experiments can get messy quickly and they can also take a long time, so sometimes shortcuts are taken. For example, while 10-fold cross-validation is frequently recommended for robust evaluation, it can be quite time-consuming, so many experiments end up using fewer folds.
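The trade-off between fold count and run time can be sketched as follows. This is a minimal illustration, not the notebook's code: it assumes a synthetic dataset from `make_classification` and a logistic regression baseline, and the variable names are made up for the example.

```python
import time

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Synthetic stand-in data (an assumption; the article uses Titanic later on)
X, y = make_classification(n_samples=1_000, n_features=10, random_state=42)

baseline = LogisticRegression(max_iter=1_000)

cv_results = {}
for n_folds in (3, 10):
    start = time.perf_counter()
    # One model fit per fold, so 10-fold CV does more work than 3-fold CV
    scores = cross_val_score(baseline, X, y, cv=n_folds, scoring="accuracy")
    cv_results[n_folds] = {
        "mean_accuracy": scores.mean(),
        "n_fits": len(scores),
        "elapsed_seconds": time.perf_counter() - start,
    }

for n_folds, result in cv_results.items():
    print(n_folds, result)
```

With a cheap baseline the difference is small, but the number of fits grows linearly with the number of folds, which matters once each fit takes minutes.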

Even though modeling isn’t the most time-consuming step, no one wants to wait for hours for their hyperparameter grid search experiments to complete.

Understanding LightGBM & Optuna’s role in model training

That’s why opting for faster tools for these tasks can significantly cut down on the time spent in the modeling phase. LightGBM is a really popular library known for its tree-based learning algorithms. It’s both rapid and memory-efficient, all while delivering accurate results.

However, even with LightGBM’s speed, big search grids can still consume a lot of time. That is where Optuna comes into play. Optuna is a versatile hyperparameter optimization framework that supports efficient state-of-the-art algorithms to quickly find the best model parameters and conveniently visualize the results.

To showcase how to use LightGBM and how much faster it is, we can compare it with the traditional way of training a tree-based model with scikit-learn.

One added benefit of LightGBM is that you do not need to one-hot encode or otherwise transform your categorical features, as the model supports them natively. However, if you have high-cardinality categorical features, it makes sense to keep only the main categories or to apply target encoding to transform them into numerical features.
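Keeping only the main categories can be sketched with pandas alone. The helper name `lump_rare_categories` and the toy city values are hypothetical, chosen for illustration. The result is cast to the pandas `category` dtype, which LightGBM's scikit-learn API picks up as a categorical feature when `categorical_feature` is left at its default of `'auto'`.

```python
import pandas as pd


def lump_rare_categories(column: pd.Series, top_n: int = 3,
                         other_label: str = "other") -> pd.Series:
    """Keep the top_n most frequent values and replace the rest with other_label."""
    top_values = column.value_counts().nlargest(top_n).index
    # Values outside the top_n (including missing values) become other_label
    lumped = column.where(column.isin(top_values), other_label)
    return lumped.astype("category")


# Toy column, invented for this example
cities = pd.Series(["Paris", "Paris", "Paris", "Rome", "Rome", "Oslo", "Lima", "Kyiv"])
city_feature = lump_rare_categories(cities, top_n=2)
print(city_feature.value_counts())
```

Target encoding is the alternative when the rare categories carry signal you do not want to lose; it is not shown here.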

Let’s pick an easy example for this comparison: the world-famous Titanic dataset!

First, you can read the data, select the relevant columns, and create the training and test sets.

Refer to Google Colab: Code 1 – Data preparation Load the Titanic dataset and create the training and test sets
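For readers without the notebook at hand, here is a hedged sketch of what such a preparation step might look like. The column names follow the Kaggle-style Titanic CSV and the feature selection is an assumption, not necessarily the one used in the Colab. To keep the sketch runnable offline, it is exercised on random stand-in rows rather than the real dataset.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Assumed column names (Kaggle-style Titanic CSV); adjust to your copy of the data
FEATURES = ["Pclass", "Sex", "Age", "SibSp", "Parch", "Fare"]
TARGET = "Survived"


def prepare_titanic(df: pd.DataFrame, test_size: float = 0.2, seed: int = 42):
    """Select columns, split into train/test sets, and fill missing ages."""
    data = df[FEATURES + [TARGET]].copy()
    # Encode the binary Sex column as 0/1 so scikit-learn trees accept it
    data["Sex"] = (data["Sex"] == "female").astype(int)

    x_train, x_test, y_train, y_test = train_test_split(
        data[FEATURES], data[TARGET],
        test_size=test_size, random_state=seed, stratify=data[TARGET],
    )
    # Fill missing ages with the training median to avoid leaking test data
    age_median = x_train["Age"].median()
    x_train = x_train.assign(Age=x_train["Age"].fillna(age_median))
    x_test = x_test.assign(Age=x_test["Age"].fillna(age_median))
    return x_train, x_test, y_train, y_test


# Random stand-in rows so the sketch runs offline; NOT real Titanic data
rng = np.random.default_rng(0)
n_rows = 200
age = rng.uniform(1, 80, n_rows)
age[rng.random(n_rows) < 0.1] = np.nan
demo_df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n_rows),
    "Sex": rng.choice(["male", "female"], n_rows),
    "Age": age,
    "SibSp": rng.integers(0, 4, n_rows),
    "Parch": rng.integers(0, 3, n_rows),
    "Fare": rng.uniform(5, 100, n_rows),
    "Survived": rng.integers(0, 2, n_rows),
})
demo_x_train, demo_x_test, demo_y_train, demo_y_test = prepare_titanic(demo_df)
print(demo_x_train.shape, demo_x_test.shape)
```

With the real data, `demo_df` would be replaced by the loaded Titanic table.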

Then, we create a generic function that trains a model on the training set and evaluates it on the test set. We also measure the training time, as we will use it to compare LightGBM and scikit-learn.

Refer to Google Colab: Code 2 – Training and Evaluation

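import time
from typing import Any, Dict

import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, confusion_matrix

def train_and_predict(model: Any, x_train: pd.DataFrame, y_train: pd.DataFrame, x_test: pd.DataFrame, plot_cm: bool = True) -> Dict[str, Any]:
  # Time only the fit call so the libraries are compared on training speed
  start_time = time.time()
  model.fit(x_train, y_train)
  elapsed_time = time.time() - start_time

  # Make predictions on the test set
  # Note: y_test is not a parameter; it is read from the notebook's global scope
  y_pred = model.predict(x_test)
  accuracy = accuracy_score(y_test, y_pred)

  # Display the confusion matrix
  if plot_cm:
    cm = confusion_matrix(y_test, y_pred)
    display(ConfusionMatrixDisplay(confusion_matrix=cm).plot(cmap=plt.cm.Blues))
  return {"model": model.__class__.__name__, "elapsed_training_time": elapsed_time, "accuracy": accuracy}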

Performance Comparison

With that, we have everything we need to compare both libraries.

Refer to Google Colab: Code 3 – Performance tests with scikit-learn and LightGBM

run_performance_tests(x_train=x_train, y_train=y_train, x_test=x_test)
The single decision tree is clearly the fastest, but there is already a big difference in training time between RandomForestClassifier and LGBMClassifier: LightGBM trains almost 5 times faster.

In this case, a simple decision tree does not perform badly in terms of accuracy, but in general single decision trees are not good enough for most problems because they have high variance.

The times above are also low because the Titanic dataset is very small; real datasets usually have many more rows. To estimate more realistic training times, we can oversample the dataset. Below you can see the results for around 350K rows:

run_performance_tests(x_train=x_train_oversampled, y_train=y_train_oversampled, x_test=x_test)

We can also see that the accuracy of the LGBMClassifier decreased significantly with oversampling, because the additional samples caused some overfitting.
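The oversampling step described above could look roughly like this. It is a sketch under assumptions: the helper `oversample_rows` is hypothetical and simply draws training rows with replacement until a target size is reached. The notebook may use a different method.

```python
import numpy as np
import pandas as pd


def oversample_rows(x: pd.DataFrame, y: pd.Series, target_rows: int, seed: int = 42):
    """Draw rows with replacement until the training set has target_rows rows."""
    rng = np.random.default_rng(seed)
    # Sample row positions once so features and labels stay aligned
    positions = rng.integers(0, len(x), size=target_rows)
    x_over = x.iloc[positions].reset_index(drop=True)
    y_over = y.iloc[positions].reset_index(drop=True)
    return x_over, y_over


# Toy data, invented for this example
toy_x = pd.DataFrame({"feature": [10, 20, 30, 40]})
toy_y = pd.Series([0, 1, 0, 1])
x_big, y_big = oversample_rows(toy_x, toy_y, target_rows=1_000)
print(len(x_big), len(y_big))
```

Only the training set is enlarged this way; the test set stays untouched so the accuracy figures remain comparable. Because the extra rows are copies, they add training cost without adding information.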

Leveraging Optuna for advanced hyperparameter tuning

Let’s now try to improve the accuracy of those models by doing a hyperparameter search and measuring how much time it takes. Let’s start with a simple CV search grid with the Random Forest classifier. We explore the maximum depth and number of trees.

We explore 30 candidates and it takes 8 seconds, increasing the accuracy of the model when compared with our previous results.

Refer to Google Colab: Code 4 – Search grid Random Forest

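%%time

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# n_estimators here is only a default; the grid search below overrides it
random_forest = RandomForestClassifier(n_estimators=100, random_state=42)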

# Define a parameter grid to search
param_grid = {
    'n_estimators': [20, 40, 60, 80, 100],  # Number of trees in the forest
    'max_depth': [None, 2, 4, 6, 8, 10],  # Maximum depth of each tree
    # 'min_samples_split': [2, 6, 10, 14],  # Minimum samples required to split a node
    # 'min_samples_leaf': [1, 4],    # Minimum samples required at each leaf node
}

# Create a GridSearchCV object
grid_search = GridSearchCV(estimator=random_forest, param_grid=param_grid, cv=2, n_jobs=-1, verbose=2)

# Fit the GridSearchCV object to the training data
grid_search.fit(x_train, y_train)

# Print the best hyperparameters found
print("Best hyperparameters:", grid_search.best_params_)

# Evaluate the model on the test set using the best hyperparameters
best_rf_classifier = grid_search.best_estimator_
test_accuracy = best_rf_classifier.score(x_test, y_test)
print("Test Accuracy:", test_accuracy)

This gives us an output of:

Now let’s use Optuna with the LGBMClassifier to also explore 30 candidates. However, in this case, we do not explicitly specify the values to explore. This has the advantage of potentially exploring a wider range of values facilitating the discovery of the optimal values. In this case, we tested the learning rate and the number of iterations, which are important in this model.

Refer to Google Colab: Code 5 – Optuna search

In this specific example, the time taken is almost the same, but we were able to explore a wider range of values. Adding more parameters to the search should not dramatically increase the time Optuna takes, because you control how many trials are run. In a search grid, by contrast, every added parameter multiplies the number of candidates, so the time grows quickly. A common workaround is to fix the best values found so far and search only over the remaining parameters. Sometimes this works, but once you start changing other parameters, the fixed values may no longer be optimal.
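The growth of a search grid is simple arithmetic: the number of candidates is the product of the number of values per parameter. The short sketch below uses the grid from Code 4 and, as an illustration, adds the two parameters that are commented out there. The helper `grid_size` is hypothetical.

```python
from math import prod


def grid_size(param_grid: dict) -> int:
    """Number of candidates in an exhaustive grid: product of the value counts."""
    return prod(len(values) for values in param_grid.values())


small_grid = {
    "n_estimators": [20, 40, 60, 80, 100],
    "max_depth": [None, 2, 4, 6, 8, 10],
}
# Same grid plus the two parameters commented out in Code 4
larger_grid = {
    **small_grid,
    "min_samples_split": [2, 6, 10, 14],
    "min_samples_leaf": [1, 4],
}

cv_folds = 2
print(grid_size(small_grid), grid_size(small_grid) * cv_folds)    # 30 candidates, 60 CV fits
print(grid_size(larger_grid), grid_size(larger_grid) * cv_folds)  # 240 candidates, 480 CV fits
```

Two extra parameters turn 30 candidates into 240, whereas an Optuna study with `n_trials=30` would still run 30 trials.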

Optuna also provides useful plots for understanding the results. Below we can see some examples of the importance of the hyperparameters searched, which can be used to plan future trials.

Refer to Google Colab: Code 6 – Optuna visualizations

We can also look at the contour plot to assess which combinations of parameters showed more potential.

These plots are also available when searching over multiple hyperparameters, so it is much easier to see at a glance how the evaluation metric behaves across the parameter values tested.

In this case, we used Optuna with LightGBM, but it could also have been used with the Random Forest model, as Optuna is model-agnostic.

Final thought

So before you plan your next hyperparameter search, if you intend to explore multiple parameters and your dataset has a significant number of rows, consider giving LightGBM and Optuna a try: it can save you many hours of waiting!
