Building Custom Machine Learning Models
Sometimes, meeting a specific business need calls for a custom machine learning model. In this article we discuss how to create such models and how to use them within the scikit-learn ecosystem. For example, we can apply scikit-learn's GridSearchCV to a custom model to find the best hyperparameters.
Basic Components of a Machine Learning Model
A (supervised) machine learning model has two main components: fit and predict. We use the fit method to learn from data and the predict method to make predictions on new data.
class MLModel():
    def __init__(self):
        pass

    def fit(self, X, y):
        "train the model on a dataset"

    def predict(self, X):
        "predict y on unseen dataset"
A Simple Custom Machine Learning Model for Classifying Iris Species
Let's consider the classic Iris dataset. The dataset consists of samples from three Iris species (Iris setosa, Iris virginica, Iris versicolor) with four features (sepal length, sepal width, petal length, petal width). We can load it from sklearn.datasets.
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True, as_frame=True)
X
|     | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) |
|-----|-------------------|------------------|-------------------|------------------|
| 0   | 5.1 | 3.5 | 1.4 | 0.2 |
| 1   | 4.9 | 3.0 | 1.4 | 0.2 |
| 2   | 4.7 | 3.2 | 1.3 | 0.2 |
| 3   | 4.6 | 3.1 | 1.5 | 0.2 |
| 4   | 5.0 | 3.6 | 1.4 | 0.2 |
| ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 |
150 rows × 4 columns
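The target y encodes the three species as the integers 0, 1, and 2, with 50 samples of each class; a quick check:

y.value_counts()   # each of the classes 0, 1, 2 appears 50 times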
X.describe()
|       | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) |
|-------|-------------------|------------------|-------------------|------------------|
| count | 150.000000 | 150.000000 | 150.000000 | 150.000000 |
| mean  | 5.843333 | 3.057333 | 3.758000 | 1.199333 |
| std   | 0.828066 | 0.435866 | 1.765298 | 0.762238 |
| min   | 4.300000 | 2.000000 | 1.000000 | 0.100000 |
| 25%   | 5.100000 | 2.800000 | 1.600000 | 0.300000 |
| 50%   | 5.800000 | 3.000000 | 4.350000 | 1.300000 |
| 75%   | 6.400000 | 3.300000 | 5.100000 | 1.800000 |
| max   | 7.900000 | 4.400000 | 6.900000 | 2.500000 |
From the summary statistics it's clear that sepal length, sepal width, petal length, and petal width should all be positive numbers. Let's create a simple custom machine learning model: train the model using Support Vector Classification (SVC), but only make a prediction if all the features are positive, and return unknown otherwise.
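Before writing the model, we can sanity-check this on the training data itself (a quick check on the X DataFrame loaded above):

# verify that every feature column contains only strictly positive values
(X > 0).all()

Every entry of the resulting Series is True, matching the min row of X.describe() above.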
import pandas as pd
from sklearn import svm

class MLModel():
    def __init__(self, kernel='linear', C=1.0):
        self.kernel = kernel
        self.C = C
        self.clf = svm.SVC(C=self.C, kernel=self.kernel)

    def fit(self, X: pd.DataFrame, y):
        "train the model on a dataset"
        self.clf.fit(X, y)

    def predict(self, X: pd.DataFrame):
        "predict y on unseen dataset"
        predictions = []
        for _, row in X.iterrows():
            # only predict when all features are strictly positive
            if (row > 0).all():
                prediction = self.clf.predict(row.to_frame().T)[0]
            else:
                prediction = 'unknown'
            predictions.append(prediction)
        return predictions
model = MLModel()
We can now train the model on the Iris dataset.
model.fit(X, y)
To try our trained model, we create three test samples. Note that the second sample has a sepal length of 0.0 and the third sample has a sepal width of -1.0.
X_new = pd.DataFrame({
    'sepal length (cm)': [2.3, 0, 6.3],
    'sepal width (cm)': [2.5, 3.0, -1],
    'petal length (cm)': [1.4, 4.2, 5.4],
    'petal width (cm)': [2.0, 2.3, 1.9]
})
X_new
|   | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) |
|---|-------------------|------------------|-------------------|------------------|
| 0 | 2.3 | 2.5 | 1.4 | 2.0 |
| 1 | 0.0 | 3.0 | 4.2 | 2.3 |
| 2 | 6.3 | -1.0 | 5.4 | 1.9 |
As expected, our model only returns a prediction for the first sample and returns unknown for the second and third samples.
model.predict(X_new)
[0, 'unknown', 'unknown']
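The integer label 0 in the output maps to a species name through the dataset's target_names:

from sklearn.datasets import load_iris

# the integer labels 0, 1, 2 correspond to these species names, in order
print(load_iris().target_names)  # ['setosa' 'versicolor' 'virginica']

So the first test sample is classified as Iris setosa.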
Using Custom Machine Learning Models within the Scikit-learn Ecosystem
In order to use our custom machine learning model within the scikit-learn ecosystem, we need to provide a few other methods:

- get_params: returns a dict of the parameters of the machine learning model.
- set_params: takes a dictionary of parameters as input and sets the parameters of the machine learning model.
- score: provides a default evaluation criterion for the problem the model is designed to solve.
We can either implement these methods ourselves or simply inherit them from sklearn.base.BaseEstimator and sklearn.base.ClassifierMixin.
BaseEstimator provides the implementation of the get_params and set_params methods. ClassifierMixin provides the implementation of the score method as the mean accuracy.
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import train_test_split

class MLModel(BaseEstimator, ClassifierMixin):
    def __init__(self, kernel='linear', C=1.0):
        # scikit-learn convention: __init__ only stores hyperparameters,
        # so that set_params (used internally by GridSearchCV) takes effect
        self.kernel = kernel
        self.C = C

    def fit(self, X: pd.DataFrame, y):
        "train the model on a dataset"
        # build the underlying classifier here, from the current hyperparameters
        self.clf = svm.SVC(C=self.C, kernel=self.kernel)
        self.clf.fit(X, y)
        return self

    def predict(self, X: pd.DataFrame):
        "predict y on unseen dataset"
        predictions = []
        for _, row in X.iterrows():
            # only predict when all features are strictly positive
            if (row > 0).all():
                prediction = self.clf.predict(row.to_frame().T)[0]
            else:
                prediction = 'unknown'
            predictions.append(prediction)
        return predictions
model = MLModel()
Since we've defined MLModel as a subclass of BaseEstimator and ClassifierMixin, we can use get_params to retrieve all the parameters and use score to compute the mean accuracy on the test dataset.
model.get_params()
{'C': 1.0, 'kernel': 'linear'}
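set_params works in the other direction. A quick usage sketch on a throwaway instance (so that model above keeps its original settings):

m = MLModel()
m.set_params(C=10)      # set_params returns the estimator, so calls can be chained
print(m.get_params())   # {'C': 10, 'kernel': 'linear'}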
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
model.fit(X_train, y_train)
model.score(X_test, y_test)
1.0
Our custom machine learning model also works fine with scikit-learn's GridSearchCV.
from sklearn.model_selection import GridSearchCV
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
clf = GridSearchCV(model, parameters)
clf.fit(X, y)
GridSearchCV(estimator=MLModel(),
param_grid={'C': [1, 10], 'kernel': ('linear', 'rbf')})
clf.best_params_
{'C': 1, 'kernel': 'linear'}
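Because GridSearchCV refits the best configuration on the whole dataset by default (refit=True), the tuned model is directly available for predictions, e.g. on the X_new samples from earlier:

# the model refit with the best hyperparameters
best_model = clf.best_estimator_
best_model.predict(X_new)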
Note that for regression problems we need to use RegressorMixin instead of ClassifierMixin; RegressorMixin implements the score method as the coefficient of determination (R²) of the prediction. See the scikit-learn documentation for more details.
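As a closing illustration, a minimal custom regressor might look like the sketch below: a toy model (hypothetical MeanRegressor) that always predicts the mean of the training targets, with RegressorMixin supplying the R² score for free.

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin

class MeanRegressor(BaseEstimator, RegressorMixin):
    "a toy regressor that always predicts the mean of the training targets"
    def fit(self, X, y):
        self.mean_ = np.mean(y)
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)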