CatBoost on Google Cloud’s AI Platform w/ CPUs & GPUs


  • CatBoost, an OSS gradient boosting framework, is the new kid on the block, with intriguing benchmark results on model quality and training/serving speed on CPUs & GPUs
  • CatBoost is simple to run on Google Cloud’s AI Platform with both the Notebooks & Training services. GitHub examples to accompany this post HERE
  • While AI Platform Training only lists TensorFlow, XGBoost, & scikit-learn as officially supported hosted frameworks, CatBoost works seamlessly by including it in the file during training job submission

CatBoost for Gradient Boosting

My focus over the last few years has been heavily skewed towards deep learning & neural networks, but this weekend I wanted to quickly catch up on the world of gradient boosting on structured data. The biggest development I seem to have missed over the last 2.5 years is the launch and growth of CatBoost.

  • OSS ML framework developed by Yandex researchers & engineers in Russia, and used in key Yandex services including search, weather, and the personal assistant
  • Model Quality: Better accuracy on several benchmarks vs. XGBoost, LightGBM, & H2O
  • Speed: Faster training & serving times on CPUs & GPUs
  • Out-of-the-box Usability: Better model quality with default parameters, sophisticated categorical/text support and visual model analysis tools
With that in mind, my goals for the weekend were to:
  1. Run through an end-to-end example on AI Platform Notebooks with the same dataset as the XGBoost tutorial (this can also be run for free on Google Colab with a few changes)
  2. Put together a simple example of running CatBoost on AI Platform Training with both CPUs & GPUs

AI Platform Notebook

Notebook on GitHub HERE


  1. Create an AI Platform Notebook instance with the ‘CUDA 10.1’ image and an NVIDIA T4 GPU
  2. Open up the JupyterLab UI
  3. Open the Terminal and install CatBoost with ‘pip install catboost’. All other Python libraries (pandas, numpy, sklearn) and GPU libraries (CUDA, NCCL) are already installed as part of the base image
  4. Run through this notebook
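For readers who want the gist without opening the notebook, here is a minimal sketch of the core training step. It assumes a pandas DataFrame with a binary `label` column. The toy columns (`color`, `size`, `weight`) and the helper names `categorical_columns` and `train_catboost` are hypothetical placeholders, not the XGBoost tutorial dataset or the notebook's actual code. CatBoost is imported inside `train_catboost`, so the column-detection helper works even before `pip install catboost`.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def categorical_columns(features: pd.DataFrame) -> list:
    """Names of non-numeric columns, which CatBoost can encode natively via cat_features."""
    return [name for name in features.columns if not is_numeric_dtype(features[name])]


def train_catboost(features: pd.DataFrame, labels: pd.Series,
                   task_type: str = "CPU", iterations: int = 200):
    """Fit a CatBoostClassifier with mostly default parameters.

    task_type="GPU" asks CatBoost to train on the notebook's T4 instead of the CPU.
    """
    from catboost import CatBoostClassifier  # installed with `pip install catboost`

    model = CatBoostClassifier(iterations=iterations, task_type=task_type, verbose=False)
    # Passing column names lets CatBoost handle the string columns itself,
    # so no manual one-hot or label encoding is needed.
    model.fit(features, labels, cat_features=categorical_columns(features))
    return model


# Hypothetical toy data standing in for the tutorial dataset.
toy = pd.DataFrame({
    "color": ["red", "blue", "red", "green", "blue", "green"],
    "size": ["S", "M", "L", "M", "S", "L"],
    "weight": [1.2, 3.4, 2.2, 0.7, 1.9, 2.8],
    "label": [0, 1, 1, 0, 1, 0],
})
toy_features, toy_labels = toy.drop(columns="label"), toy["label"]
print(categorical_columns(toy_features))  # ['color', 'size']
```

With CatBoost installed, you would call `model = train_catboost(toy_features, toy_labels, task_type="GPU")` and then `model.predict(toy_features)`. This follows the same fit/predict pattern as scikit-learn and XGBoost.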


  • CatBoost works with all your existing Python & CUDA libraries, so setup was a breeze. It took me <5 min to get up and running
  • The CatBoost syntax and APIs are well designed and follow familiar patterns used by other libraries
  • GPUs were simple to use. I first used a prebuilt utility to check whether the T4 GPU was recognized:
from catboost.utils import get_gpu_device_count
print('I see %i GPU devices' % get_gpu_device_count())
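The same utility can also set CatBoost's `task_type` parameter, so that one code path runs both on the T4 notebook and on a CPU-only machine. This is a minimal sketch. The helper name `pick_task_type` is hypothetical, and falling back to CPU when CatBoost cannot be imported is an assumption about the desired behavior.

```python
def pick_task_type() -> str:
    """Return "GPU" if CatBoost can see at least one CUDA device, otherwise "CPU"."""
    try:
        from catboost.utils import get_gpu_device_count
    except ImportError:
        # CatBoost is not installed here, so there is no GPU path to take.
        return "CPU"
    return "GPU" if get_gpu_device_count() > 0 else "CPU"


task_type = pick_task_type()
print("Training with task_type=%s" % task_type)
# The result can then be passed straight through, e.g.
# CatBoostClassifier(task_type=task_type)
```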

AI Platform Training

Training Package on GitHub HERE


  1. Install the Google Cloud SDK on your development machine. If you are using Google Cloud Shell, a Compute Engine VM, or the Terminal on the AI Platform Notebook instance from the previous section, the SDK should already be installed.
  2. Create a Google Cloud Storage bucket to store your model once training completes.
  3. Package the training code in the following directory and file structure:
root directory
-trainer directory
-scripts directory
  • file includes the same code from the Jupyter notebook example in the previous section
  • includes the gcloud command to submit the training job
  • file lists any Python libraries that must be installed on the on-demand server(s) and are not part of the hosted service by default. In this case you only need to specify ‘catboost’.
gcloud ai-platform jobs submit training $JOB_NAME \
--job-dir $JOB_DIR \
--package-path $TRAINING_PACKAGE_PATH \
--module-name $MAIN_TRAINER_MODULE \
--region $REGION \
--runtime-version=$RUNTIME_VERSION \
--python-version=$PYTHON_VERSION \
--scale-tier $SCALE_TIER
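The CPU and GPU runs differ only in machine configuration, so the submit command can be assembled programmatically. This is a hedged sketch. `BASIC` (one CPU worker) and `BASIC_GPU` (one worker with a GPU) are predefined AI Platform Training scale tiers. The function name, the `trainer.task` module, the bucket path, and the runtime/Python versions below are placeholders to replace with your own.

```python
import shlex


def build_submit_command(job_name: str, job_dir: str, package_path: str,
                         module_name: str, region: str, runtime_version: str,
                         python_version: str, use_gpu: bool = False) -> list:
    """Return the gcloud argv that submits one AI Platform Training job."""
    # BASIC is a single CPU worker; BASIC_GPU is a single worker with one GPU.
    scale_tier = "BASIC_GPU" if use_gpu else "BASIC"
    return [
        "gcloud", "ai-platform", "jobs", "submit", "training", job_name,
        "--job-dir", job_dir,
        "--package-path", package_path,
        "--module-name", module_name,
        "--region", region,
        "--runtime-version=" + runtime_version,
        "--python-version=" + python_version,
        "--scale-tier", scale_tier,
    ]


# Placeholder values; substitute your own bucket, module, and versions.
cmd = build_submit_command(
    job_name="catboost_gpu_001",
    job_dir="gs://YOUR_BUCKET/catboost",
    package_path="trainer/",
    module_name="trainer.task",
    region="us-central1",
    runtime_version="2.1",
    python_version="3.7",
    use_gpu=True,
)
print(shlex.join(cmd))
```

Passing the list to `subprocess.run(cmd, check=True)` would submit the job. Note that the CatBoost code inside the package must still request `task_type="GPU"` to use the GPU that the tier provides, because CatBoost trains on the CPU by default.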


  • While AI Platform Training only lists TensorFlow, XGBoost, & scikit-learn as officially supported hosted frameworks, CatBoost works seamlessly by including it in the file. For jobs with more complex runtime requirements, you can also submit a custom container.
  • TODO: Add a few custom arguments so the code supports both CPUs & GPUs automatically. Currently I have to manually comment the GPU task_type in or out of the code.
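One way to close this TODO is to expose the device as a command-line flag and resolve an `auto` default from the detected GPU count. The following is a sketch under assumptions. The `--task-type` and `--iterations` flags and the helper functions are hypothetical. `--job-dir` is included because AI Platform Training passes it to the trainer when the job is submitted with `--job-dir`. The GPU count is a parameter, which in practice would come from `catboost.utils.get_gpu_device_count()`, so the logic can be checked without a GPU.

```python
import argparse


def parse_args(argv=None):
    """Parse trainer flags; argv=None reads sys.argv as usual."""
    parser = argparse.ArgumentParser(description="CatBoost trainer (sketch)")
    parser.add_argument("--job-dir", default=".",
                        help="Output location for the trained model")
    parser.add_argument("--task-type", choices=["auto", "CPU", "GPU"], default="auto",
                        help="CatBoost task_type; 'auto' uses a GPU only if one is visible")
    parser.add_argument("--iterations", type=int, default=500)
    return parser.parse_args(argv)


def resolve_task_type(requested: str, gpu_count: int) -> str:
    """Honor an explicit CPU/GPU choice; otherwise pick GPU only when one exists."""
    if requested != "auto":
        return requested
    return "GPU" if gpu_count > 0 else "CPU"


def catboost_params(args, gpu_count: int) -> dict:
    """Build CatBoostClassifier keyword arguments from the parsed flags."""
    return {
        "iterations": args.iterations,
        "task_type": resolve_task_type(args.task_type, gpu_count),
    }


# Simulate the flags a GPU job might receive (values are illustrative).
example_args = parse_args(["--job-dir", "gs://YOUR_BUCKET/catboost", "--task-type", "auto"])
print(catboost_params(example_args, gpu_count=1))  # {'iterations': 500, 'task_type': 'GPU'}
```

Inside the trainer, this would become `CatBoostClassifier(**catboost_params(parse_args(), get_gpu_device_count()))`. The same package could then run under either scale tier without code edits.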

Future follow-ups

  • Serve the model for both batch and online needs
  • Play with tuning options to maximize model quality
  • Play with the built-in visualization and explainability tools