Developing custom scikit-learn transformers and estimators
scikit-learn offers a wide range of machine learning models, but it goes well beyond that by providing other tools such as hyperparameter optimization via GridSearchCV or composed estimators via Pipeline. One of the characteristics I like most about scikit-learn is its consistent API: all estimators implement the same basic methods (fit and predict). This consistency has been immensely useful to the ML open source community, since a lot of third-party packages are developed with this in mind (e.g. Keras); hence they are able to interface with each other.
Often we need to implement functionality that does not exist in scikit-learn or any other package. If we conform to scikit-learn's API, we can limit ourselves to developing a custom transformer/estimator, and our code will interface nicely with scikit-learn modules.
In this blog post, I will show how to build custom transformers and estimators, as well as discuss implementation details to do this correctly. The official docs contain all you need to know but here are the most important facts:
- All constructor (the __init__ function) parameters should have default values
- Constructor parameters should be added as attributes without any modification
- Attributes estimated from data must have a name with a trailing underscore
There are other rules, but you can use utility functions provided by scikit-learn to take care of them. A check_estimator function is also provided to exhaustively verify that your implementation is correct, and an official code template is available.
Transformer use case: verifying model’s input
scikit-learn estimators were originally designed to operate on numpy arrays (although there is ongoing work to better interface with pandas data frames). For practical purposes, this means our estimators have no notion of column names (only the input shape is verified in order to raise errors): if the columns are shuffled, the transformer/estimator will not complain, but the prediction will be meaningless.
Our custom transformer (an object that implements fit and transform) adds this capability: when used in a Pipeline object, it will verify that we are getting the right input columns. The (commented) implementation looks as follows:
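The post's original listing is not reproduced here, so below is a minimal sketch of what such a transformer could look like. The class name InputGuard and the strict/non-strict behavior come from the post; the exact validation logic (remembering column names at fit time, raising on missing columns, warning on extra columns in non-strict mode) is my assumption.

```python
import warnings

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_is_fitted


class InputGuard(BaseEstimator, TransformerMixin):
    """Verify at transform time that columns match those seen at fit time.

    Hypothetical sketch: strict=True raises on unexpected columns,
    strict=False warns and drops them; missing columns always raise.
    """

    def __init__(self, strict=True):
        # constructor parameters must be stored unmodified (scikit-learn rule)
        self.strict = strict

    def fit(self, X, y=None):
        # attributes estimated from data get a trailing underscore
        self.expected_ = list(X.columns) if hasattr(X, "columns") else None
        check_array(X)  # validates that X is a numeric 2D array-like
        return self

    def transform(self, X):
        check_is_fitted(self, "expected_")
        if self.expected_ is None or not hasattr(X, "columns"):
            return check_array(X)
        missing = set(self.expected_) - set(X.columns)
        extra = set(X.columns) - set(self.expected_)
        if missing:
            raise ValueError(f"missing columns: {sorted(missing)}")
        if extra:
            if self.strict:
                raise ValueError(f"unexpected columns: {sorted(extra)}")
            warnings.warn(f"dropping unexpected columns: {sorted(extra)}")
        # reorder columns to match what was seen during fit
        return check_array(X[self.expected_])
```

When the input is a plain numpy array there are no column names to check, so the sketch falls back to shape validation via check_array.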
The sklearn.utils.validation module provides utility functions that let us pass some of the check_estimator tests without having to implement the logic ourselves (I actually had to make a few modifications to my original implementation to fix errors thrown by check_estimator). These utility functions transform inputs (check_X_y, check_array) to return the expected format (numpy arrays) and throw appropriate exceptions when this is not possible. check_is_fitted only raises an error if a call to predict is attempted without fitting the model first.
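As a quick standalone illustration of these helpers (not taken from the post's code):

```python
import warnings

from sklearn.utils.validation import check_X_y, check_array

X = [[1.0, 2.0], [3.0, 4.0]]
y = [[0], [1]]  # a 2D column vector instead of a 1D array

# check_X_y converts both inputs to numpy arrays; the column vector y
# is flattened to 1D (emitting a DataConversionWarning)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    X_arr, y_arr = check_X_y(X, y)
print(X_arr.shape, y_arr.shape)  # (2, 2) (2,)

# check_array rejects inputs that cannot be coerced to a 2D numeric array
array_error = False
try:
    check_array([1.0, 2.0, 3.0])  # 1D input raises by default
except ValueError:
    array_error = True
print("check_array raised ValueError:", array_error)
```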
We now verify that our transformer passes all the tests:
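The call looks like the following; I am using a built-in transformer here so the snippet is self-contained (the post runs the same check on its custom class):

```python
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import check_estimator

# check_estimator runs the full suite of API checks and raises at the
# first failure; reaching the print means every check passed
check_estimator(StandardScaler())
print("all checks passed")
```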
Passing all tests is not strictly necessary for your transformer (or estimator) to integrate correctly with other scikit-learn modules, but doing so ensures that your implementation is robust: it handles common scenarios on behalf of the user (e.g. passing a 2D array with one column as y instead of a 1D array) and throws informative errors. Given scikit-learn's large user base, this is a must. However, for some very customized implementations, passing all the tests is simply not possible, as we will see in the custom estimator use case.
For now, let’s verify that our transformer plays nicely with Pipeline and GridSearchCV:
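The integration sketched below substitutes a trivial pass-through transformer for the custom one so the snippet runs on its own; the pipeline step names, dataset, and grid are illustrative assumptions:

```python
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline


class PassThrough(BaseEstimator, TransformerMixin):
    """Trivial stand-in for the custom transformer."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


X, y = load_iris(return_X_y=True)
pipe = Pipeline([("guard", PassThrough()),
                 ("clf", LogisticRegression(max_iter=1000))])

# grid search tunes the final step's C through the pipeline
grid = GridSearchCV(pipe, {"clf__C": [0.1, 1.0]}, cv=3)
grid.fit(X, y)
print(grid.best_params_)
```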
We now verify that our transformer throws an error if a column is missing:
If we add a column but switch to non-strict mode, we get a warning instead of an error:
Estimator use case: logging model’s predictions
Say we want to log all our predictions to monitor a production model. For the sake of the example, we will just use the logging module, but the same logic applies to other methods such as saving predictions to a database. There are a few nuances here: Pipeline requires all intermediate steps to be transformers (fit/transform), which means we can only add our model (the one that implements predict) at the end. Since we cannot split our logging into two steps, we have to wrap an existing estimator and add the logging functionality to it; from the outside, our custom estimator will look like any other standard estimator.
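A minimal sketch of such a wrapper is shown below. The class name LoggingEstimator and the fact that it takes an estimator class as a constructor argument come from the post; everything else (the default model, attribute names) is my assumption.

```python
import logging

from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression


class LoggingEstimator(BaseEstimator):
    """Hypothetical sketch: wrap any estimator class and log predictions."""

    def __init__(self, est_class=LinearRegression):
        # storing a class as a parameter is what breaks
        # check_parameters_default_constructible
        self.est_class = est_class
        # an extra attribute set in __init__ breaks
        # check_no_attributes_set_in_init, but we need it for logging
        self._logger = logging.getLogger(__name__)

    def fit(self, X, y):
        self.model_ = self.est_class()
        self.model_.fit(X, y)
        return self

    def predict(self, X):
        preds = self.model_.predict(X)
        self._logger.info("predictions: %s", preds)
        return preds

    def __getattr__(self, name):
        # delegate unknown attributes (e.g. coef_) to the wrapped model;
        # the guard avoids infinite recursion before fit sets model_
        if name.startswith("_") or name == "model_":
            raise AttributeError(name)
        return getattr(self.model_, name)
```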
The three considerations that apply to transformers also apply to estimators, plus a fourth one (copied directly from scikit-learn's documentation):
check_estimator has a generate_only parameter that lets us run checks one by one instead of failing at the first error. Let's use that option to check LoggingEstimator.
The check names aren’t very informative, so I took a look at the source code.
check_parameters_default_constructible checks that the estimator's __init__ parameters are of certain types; our estimator passes a class as an argument, which is why it breaks, but this shouldn't be an issue when interfacing with other components. I don't know why they restrict the argument types; my guess is that they want to avoid problems with objects that don't play nicely with the multiprocessing module.
check_no_attributes_set_in_init is also about __init__ arguments: according to the spec, we should not set any attributes other than the arguments, but we need an extra one for logging to work. This should not affect integration either.
Finally, check_supervised_y_2d checks that a warning is issued if a 2D numpy array is passed to fit, since it has to be converted to a 1D array. Our custom estimator wraps any estimator, which could be multi-output, so we cannot use the utility functions to fix this.
The bottom line is that check_estimator runs a very strict test suite; if your estimator does not pass all the tests, it does not mean it won't work, but you'll have to be more careful about your implementation.
Let’s now see our pipeline in action. Note that we are also including our InputGuard, and we change the underlying model in LoggingEstimator to demonstrate that it works with any estimator.
Let’s now configure the logging module and enable it in our custom estimator:
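A minimal configuration sketch (the handler format is an assumption; the post's exact setup is not shown here):

```python
import logging

# show INFO-level records on stderr; force=True (Python 3.8+) replaces
# any handlers configured earlier in the session
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
    force=True,
)
logging.getLogger(__name__).info("logging enabled")
```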
The following line shows our logging in effect:
Since we implemented __getattr__, any model-specific attribute also works; let's get the linear model's coefficients:
Appendix: making our estimator work with pickle (or any other pickling mechanism)
Pickling an object means saving it to disk. This is useful if we want to fit a model and then deploy it (be careful when doing this!), but it is also needed if we want our model to work with the multiprocessing module. Some objects are picklable and some are not (this also depends on which library you are using). logger objects do not work with the pickle module, but we can easily fix this by deleting the logger before saving to disk and initializing it after loading; this boils down to adding two more methods: __getstate__ and __setstate__. If you are interested in the details, read this.
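The pattern boils down to the following standalone sketch (the class name is hypothetical; note that on recent Python versions loggers may pickle by name anyway, but the pattern still applies to any unpicklable attribute):

```python
import logging
import pickle


class PredictionLogger:
    """Minimal sketch of the __getstate__/__setstate__ pattern."""

    def __init__(self):
        self._logger = logging.getLogger(__name__)

    def __getstate__(self):
        # copy the instance dict and drop the logger before pickling
        state = self.__dict__.copy()
        del state["_logger"]
        return state

    def __setstate__(self, state):
        # restore the attributes, then re-create the logger
        self.__dict__.update(state)
        self._logger = logging.getLogger(__name__)


restored = pickle.loads(pickle.dumps(PredictionLogger()))
print(type(restored._logger))
```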
Closing remarks
In this post we showed how to develop transformers and estimators compatible with scikit-learn. Given how many details the API has, using the check_estimator function will guide you through the process. However, if your implementation contains non-standard behavior (like ours), your custom objects will fail some tests even if they integrate correctly with other modules. In that case, you'll have to be careful about your implementation; using check_estimator with generate_only=True is useful for getting a list of failing tests and deciding whether each failure is acceptable.
Following the scikit-learn API spec gives you access to a wide set of ML tools so you can focus on implementing your custom model and still use other modules for grid search, cross validation, etc. This is a huge time saver for ML projects.
Source code for this post is available here.
This post was generated using scikit-learn version:
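(The version string itself was not preserved in this copy; it can be printed with:)

```python
import sklearn

# scikit-learn exposes its version as a plain string
print(sklearn.__version__)
```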
Originally published at ploomber.io