[python-package] Can a LGBMClassifier be refit? #7157

@bhvieira

I am migrating from the native Python API to the scikit-learn one, but I need the .refit method. I didn't file this as a bug because I'm not sure whether missing feature parity would be considered a bug by the LightGBM team.
Naively reassigning the refitted Booster object to the classifier breaks .predict_proba: it returns a 1-D array instead of one column per class.

Is there another way to refit a LGBMClassifier?

import numpy as np
from sklearn.datasets import make_classification
from lightgbm import LGBMClassifier
from copy import deepcopy

# Create simple synthetic data
X_train, y_train = make_classification(n_samples=1000, n_features=20, n_classes=5,
                                       n_informative=15, random_state=42, n_redundant=0)
X_adapt, y_adapt = make_classification(n_samples=100, n_features=20, n_classes=5,
                                       n_informative=15, random_state=43, n_redundant=0)

# Train classifier
clf = LGBMClassifier(objective="multiclass", verbose=-1, n_estimators=10)
clf.fit(X_train, y_train)

# Get predictions before refitting
probs_before = clf.predict_proba(X_adapt)
print(f"Before refit shape: {probs_before.shape}")

# Try to refit with adaptation data
clf_refit = deepcopy(clf)
lgbm_refit = clf_refit._Booster.refit(X_adapt, y_adapt, decay_rate=0.0, verbose=-1)

# Reassign refitted booster back to the classifier
clf_refit._Booster = lgbm_refit

# Try to get predictions after refitting
probs_after = clf_refit.predict_proba(X_adapt)
print(f"After refit shape: {probs_after.shape}")

# This raises a ValueError: shapes (100, 5) and (100,) cannot be broadcast together
print(f"Changed? {not np.allclose(probs_before, probs_after)}")

This prints the following, then raises on the final comparison:

Before refit shape: (100, 5)
After refit shape: (100,)
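
To isolate why the last line of the script fails, here is a minimal NumPy-only sketch. It uses placeholder zero arrays with the two shapes printed above, not real LightGBM output, and `safe_changed` is a hypothetical helper that is not part of LightGBM or scikit-learn. It only shows that the comparison fails because of the shape mismatch, and how to guard against it; it does not explain why the refitted booster returns a 1-D array.

```python
import numpy as np

# Placeholder arrays with the two shapes reported above (not real model output).
probs_before = np.zeros((100, 5))  # one column per class
probs_after = np.zeros(100)        # 1-D output, one value per row


def safe_changed(a, b):
    """Hypothetical helper: None if the shapes differ, else whether values changed."""
    if a.shape != b.shape:
        return None
    return not np.allclose(a, b)


# np.allclose broadcasts its inputs; trailing dimensions 5 and 100 are incompatible.
try:
    np.allclose(probs_before, probs_after)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

print(f"Broadcast failed: {broadcast_failed}")
print(f"Changed? {safe_changed(probs_before, probs_after)}")
```

Checking the shapes first turns the crash into an explicit "not comparable" result, which makes it clearer that the problem lies in the output of predict_proba and not in the comparison itself.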

I have
name = "lightgbm"
version = "4.6.0"

and
name = "scikit-learn"
version = "1.7.2"
