Build More Sklearn Pipeline Tests - Part 2
lwgray opened this issue · 0 comments
lwgray commented
Add sklearn pipeline tests, modeled on the examples below, for each of these visualizers:
- PredictionError
- SilhouetteVisualizer
- KElbowVisualizer
- InterclusterDistance
- GridSearchColorPlot
Example tests, written for the CVScores visualizer, are shown below:
def test_within_pipeline(self):
    """
    Check that a CVScores visualizer used as the final step of an sklearn
    Pipeline can be fit through the pipeline and then retrieved by its step
    name for image comparison.
    """
    # Load the mushroom dataset and one-hot encode its features into a dense
    # array so the Bernoulli naive Bayes model can consume them.
    dataset = load_mushroom(return_dataset=True)
    features, labels = dataset.to_numpy()
    features = OneHotEncoder().fit_transform(features).toarray()

    # Fixed seed keeps the fold assignment (and thus the image) reproducible.
    splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=11)

    steps = [
        ("minmax", MinMaxScaler()),
        ("cvscores", CVScores(BernoulliNB(), cv=splitter)),
    ]
    pipeline = Pipeline(steps)
    pipeline.fit(features, labels)

    # Pull the fitted visualizer back out of the pipeline by step name.
    visualizer = pipeline["cvscores"]
    visualizer.finalize()
    self.assert_images_similar(visualizer, tol=2.0)
def test_within_pipeline_quickmethod(self):
    """
    Check that the visualizer returned by the cv_scores quickmethod can be
    placed as a step in an sklearn Pipeline and retrieved by its step name
    for image comparison.
    """
    features, labels = load_mushroom(return_dataset=True).to_numpy()
    features = OneHotEncoder().fit_transform(features).toarray()
    splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=11)

    # NOTE(review): cv_scores is called here directly on the one-hot
    # features; the pipeline below is never fit, so the MinMaxScaler step is
    # never applied to the data this visualizer was built from.
    # NOTE(review): random_state is passed to cv_scores alongside its other
    # keyword arguments -- confirm the visualizer actually accepts/uses it.
    quick_visualizer = cv_scores(
        BernoulliNB(),
        features,
        labels,
        cv=splitter,
        show=False,
        random_state=42,
    )

    pipeline = Pipeline([
        ("minmax", MinMaxScaler()),
        ("cvscores", quick_visualizer),
    ])
    self.assert_images_similar(pipeline["cvscores"], tol=2.0)
def test_pipeline_as_model_input(self):
    """
    Check that CVScores accepts an sklearn Pipeline (min-max scaler followed
    by a Bernoulli naive Bayes classifier) as the model it wraps.
    """
    features, labels = load_mushroom(return_dataset=True).to_numpy()
    features = OneHotEncoder().fit_transform(features).toarray()
    splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=11)

    # The whole pipeline is handed to the visualizer as its estimator.
    estimator = Pipeline([
        ("minmax", MinMaxScaler()),
        ("nb", BernoulliNB()),
    ])

    visualizer = CVScores(estimator, cv=splitter)
    visualizer.fit(features, labels)
    visualizer.finalize()
    self.assert_images_similar(visualizer, tol=2.0)
def test_pipeline_as_model_input_quickmethod(self):
    """
    Check that the cv_scores quickmethod accepts an sklearn Pipeline
    (min-max scaler followed by a Bernoulli naive Bayes classifier) as its
    model argument.
    """
    features, labels = load_mushroom(return_dataset=True).to_numpy()
    features = OneHotEncoder().fit_transform(features).toarray()
    splitter = StratifiedKFold(n_splits=2, shuffle=True, random_state=11)

    estimator = Pipeline(steps=[
        ("minmax", MinMaxScaler()),
        ("nb", BernoulliNB()),
    ])

    # show=False keeps the quickmethod from displaying the figure so the
    # returned visualizer can be compared against the baseline image.
    visualizer = cv_scores(estimator, features, labels, cv=splitter, show=False)
    self.assert_images_similar(visualizer, tol=2.0)
@DistrictDataLabs/team-oz-maintainers