API design discussion
Closed this issue · 5 comments
Summary
It seems to me that `revert` and external `cache`s can be completely avoided by using the usual `fit` and `predict` approach for machine learning pipelines. And this can be done in a functional, immutable fashion.
Discussion
I skimmed a couple of sections in your book. Here's a quote from Chapter 7:
A very common workflow in geospatial data science consists of:
- Transforming the data to an appropriate sample space for geostatistical analysis
- Doing additional modeling to predict variables in new geospatial locations
- Reverting the modeling results with the saved pipeline and cache
I also skimmed through the mineral deposits example in Chapter 12 to see this approach in action. I might have misunderstood something, but it seems like the overall goal is to fit a Kriging model to some data, and then use that Kriging model to make predictions at the various points (or centroids) on a grid.
The approach taken by TableTransforms (and GeoStats) is to apply a pipeline and save an external `cache`, make a prediction, and then use the cache to `revert`.
I might be getting the details below a bit wrong, because GeoStats probably has special handling for the spatial coordinates (which are input features to the Kriging model), but overall it seems like the traditional machine learning approach can be applied, something like this:
model = fit(untrained_model, data, target)
interp = predict(model, grid)
...where `untrained_model` is a pipeline that has some sort of Kriging/interpolation model at the end of it. This way the user doesn't need to handle the `cache` object themselves. This can still be done with a functional, immutable approach. For example, `Center` might look something like this:
struct Center{T}
    means::T
end

Center() = Center(nothing)

function fit(c::Center, data)
    # means = ...
    Center(means)
end
Or, of course, you could use a `Union{T, Nothing}`:

struct Center
    means::Union{Vector{Float64}, Nothing}
end
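The same fit-returns-a-new-immutable-object pattern can be sketched outside Julia as well. Here is a rough Python analogue of the snippet above (the `Center` name and column-of-numbers data layout mirror the Julia sketch and are purely illustrative):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class Center:
    # None until fitted; a fitted copy carries the column means.
    means: Optional[Tuple[float, ...]] = None

def fit(c: Center, columns) -> Center:
    # Return a NEW Center carrying the fitted state; the input is untouched.
    means = tuple(sum(col) / len(col) for col in columns)
    return Center(means)

centered = fit(Center(), [[1.0, 3.0], [2.0, 4.0]])
print(centered.means)  # (2.0, 3.0)
```

An unfitted `Center()` holds only `None`, so it stays as cheap to pass around as a hyperparameter-only struct.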
Hi @CameronBieganek! Thanks for sharing these ideas. We carefully designed our transforms so that they only store the hyperparameters, which are static in memory and trivial to pass around in parallel jobs. For example, our PCA transform only stores the output dimension and the column names. The cache is much more expensive to pass around, as it contains the principal component basis.
Another issue with the design you suggested is that it glues the transform hyperparameters to the specific data set used in the apply step. This was one of the main reasons we did not advance with MLJ.jl or FeatureTransforms.jl. If you take a look at the mineral deposit example in the book, we fit the Kriging model to the drill hole samples (a set of points along a trajectory) and predict on the Cartesian grid. Saving information about the trajectory doesn't help in the revert step. The caches we need to save are sometimes very complicated.
Besides the two issues above, we are aiming for a more general approach to pipeline optimization, which may include neural networks. If we had to save the weights of a neural network model inside the struct that is passed around as a transform, we would be screwed.
I will close the issue, but feel free to continue the discussion here or on Zulip.
I'm not really convinced by your argument about parallel processing. Note that in the example I gave above the transforms are still immutable, and transforms/models that have not yet been fit have all their fitted parameter fields set to `nothing`, so they are still lightweight to pass around.
I now understand why you have `revert`. In the mineral deposit example, the transforms are all applied to the target values, not the feature values. So `revert` is analogous to `TransformedTargetRegressor` in scikit-learn and `TransformedTargetModel` in MLJ. The `TransformedTargetRegressor` and `TransformedTargetModel` types are admittedly a little clunky, but they get the job done, and they have the advantage that the user can stick with the usual `fit` and `predict` ML workflow without having to manually handle the `cache` object. Perhaps there is a cleaner API somewhere between the sklearn approach and the current approach in TableTransforms.jl.
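For reference, a minimal scikit-learn sketch of that target-transform-and-revert workflow might look like this (toy data; the log/exp pair is just one example of an invertible target transform):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.compose import TransformedTargetRegressor

# Toy data: the target is exponential in the feature, so it becomes
# linear after a log transform.
X = np.arange(1.0, 11.0).reshape(-1, 1)
y = np.exp(0.5 * X.ravel())

# The regressor is fit on log(y); predict() applies exp automatically,
# so the "revert" step is handled internally rather than via a cache
# object the user carries around.
model = TransformedTargetRegressor(regressor=LinearRegression(),
                                   func=np.log, inverse_func=np.exp)
model.fit(X, y)
pred = model.predict(X)  # already back in the original target units
```

The user never sees the transformed targets; fitting and reverting are both hidden behind the usual `fit`/`predict` pair.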
Taking a higher level perspective, the main reason I don't like `cache` is because it makes the API more verbose and complicated (in my opinion), and because it leaks implementation details to the user. There is no reason why the user should have to manually handle and worry about `cache` objects.
Apologies, but now I'm going to move on to a more general design discussion. Hopefully you can take it in the spirit of trying to improve the ML ecosystem. I think TableTransforms.jl has many nice aspects, but I think there are also some elements of the design that could be improved. I could open separate issues, but this is probably more of an open-ended design discussion. (I don't have Zulip or Slack.)
Taking a look at this line of code from the book:
interp = samples |> InterpolateNeighbors(blocks, models...)
That's a rather odd table transform. It takes a table with 2,000 rows and returns a table with ~100,000 rows. I think a better conceptual model for what is going on is that there are two steps: `fit` a model pipeline with `samples`, and then use that pipeline to `predict` on the new data `blocks`. Also, introducing more verbs (lowercase function calls) into the workflow makes it more clear where the action is happening and what kind of action is happening.
It also might be helpful to make a distinction between table queries and ML/statistics/data science transformations (which normally preserve the number of rows). Right now it is easy to make an invalid ML pipeline:
model = Filter(row -> row.a > 10) → RandomForestRegressor()
It's ok to filter the training data, but it is not ok to filter the prediction data. When you put a model into production, you usually need to make a prediction on every input observation. Even before production, in the context of a train/test split, it could so happen that you filter out all the data in the test set. At a minimum, your specified train-test split fraction will not hold if you apply the filtering as part of the model (after the train-test split). (I know I'm extrapolating your design by attaching an ML model at the end of the pipe, but that is the typical ML approach.)
This is a bit of bikeshedding, but `apply` and `reapply` are very generic verbs. I prefer `fit`, `transform`, and `fit_transform` because they are more specific and because they are already well established terms in the ML community. (At least within the scikit-learn community, but that accounts for a pretty large percentage of ML practice.)
I originally typed the following up, but then I realized that you can do the exact same thing in scikit-learn. It's just that the documentation makes it clear that the normal pattern is to attach an estimator to the end of a pipeline, and then provide the `fit` and `predict` methods with the raw (untransformed) `X` and `y` data.
Making transforms callable is cute, but it makes it easy for users to make mistakes. One can easily use an untrained pipeline for both training and prediction, like this:
transform = ZScore() → EigenAnalysis(:V)
modelfit = fit(model, transform(Xtrain), ytrain)
ŷ = predict(modelfit, transform(Xtest))
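The scikit-learn analogue of that pitfall can be sketched as follows; `StandardScaler` plays the role of the untrained transform, and the toy arrays are purely illustrative:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

Xtrain = np.array([[1.0], [2.0], [3.0]])
Xtest = np.array([[10.0], [20.0], [30.0]])

scaler = StandardScaler()
Xtr = scaler.fit_transform(Xtrain)       # fit on training data: fine
Xte_wrong = scaler.fit_transform(Xtest)  # mistake: silently REFITS on test data

# What was intended: transform the test set with the TRAINING statistics.
Xte_right = StandardScaler().fit(Xtrain).transform(Xtest)

# The refitted version standardizes the test set against its own mean and
# standard deviation, so the two results disagree.
print(np.allclose(Xte_wrong, Xte_right))  # False
```

In both ecosystems the API happily lets you refit on test data; the documentation, not the type system, is what steers users to the correct pattern.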
Thank you for the feedback. Below are specific comments.
> Taking a higher level perspective, the main reason I don't like `cache` is because it makes the API more verbose and complicated (in my opinion), and because it leaks implementation details to the user. There is no reason why the user should have to manually handle and worry about `cache` objects.
Maybe you are misunderstanding the goal of the `cache` in the `revert` step. It is not something that sk-learn or MLJ support, as far as I know. It is about "undoing" pipelines; it has nothing to do with fit/predict. Our fit/predict is quite clean here: it consists of creating the transform object on "training" data and calling it as a functor on "test" data. We are still brainstorming this API in the `Learn` transform in StatsLearnModels.jl, which adheres to TableTransforms.jl.
> Taking a look at this line of code from the book: `interp = samples |> InterpolateNeighbors(blocks, models...)` That's a rather odd table transform. It takes a table with 2,000 rows and returns a table with ~100,000 rows. I think a better conceptual model for what is going on is that there are two steps: `fit` a model pipeline with `samples`, and then use that pipeline to `predict` on the new data `blocks`. Also, introducing more verbs (lowercase function calls) into the workflow makes it more clear where the action is happening and what kind of action is happening.
Try to think outside the `fit`/`predict` box of other frameworks. The `InterpolateNeighbors` transform is a geospatial transform; it has nothing to do with statistical learning models where you have a "train" table and a "test" table of features. It only takes a geospatial domain (blocks) and performs interpolation.
Your entire discussion should probably be narrowed down to our `Learn` transform, which is more related to what MLJ and sk-learn can do.
> It also might be helpful to make a distinction between table queries and ML/statistics/data science transformations (which normally preserve the number of rows). Right now it is easy to make an invalid ML pipeline: `model = Filter(row -> row.a > 10) → RandomForestRegressor()` It's ok to filter the training data, but it is not ok to filter the prediction data.
Thanks, but I disagree. The idea of the TransformsBase.jl API is to be able to combine all sorts of transforms agnostically. We don't need categorization to create sophisticated pipelines involving geometric, statistical, cleaning, and other transforms.
Also, your argument here and in other parts of the text below is not very good. You are saying something like: "users can do messy things, so we should limit the power of our interface to only handle a subset of features".
> This is a bit of bikeshedding, but `apply` and `reapply` are very generic verbs. I prefer `fit`, `transform`, and `fit_transform` because they are more specific and because they are already well established terms in the ML community. (At least within the scikit-learn community, but that accounts for a pretty large percentage of ML practice.)
As far as I understand it, `apply` and `reapply` have a different purpose than `fit` and `fit_transform` in other frameworks. Also, the latter are jargon, and our transforms go beyond ML transforms only.
> Maybe you are misunderstanding the goal of the `cache` in the `revert` step. It is not something that sk-learn or MLJ support, as far as I know. It is about "undoing" pipelines; it has nothing to do with fit/predict.
Yes, they do support it, as I mentioned above. The `revert` functionality is handled by `TransformedTargetRegressor` in sklearn and `TransformedTargetModel` in MLJ. I admit that those types do not provide the prettiest interface ever, but they do have some advantages (which I also mentioned above).
Aside from undoing target transformations, which sklearn and MLJ provide, I'm not sure what other use case there is for "undoing" a pipeline.
> Try to think outside the fit/predict box of other frameworks.
I'm perfectly capable of thinking outside the fit/predict box. I just happen to think that fit/predict is a better abstraction for what is going on here. Eventually you end up contorting yourself to adhere to the "Everything is a table transform" philosophy.
> The `InterpolateNeighbors` transform is a geospatial transform; it has nothing to do with statistical learning models where you have a "train" table and a "test" table of features. It only takes a geospatial domain (blocks) and performs interpolation.
Perhaps geospatial scientists are less interested in testing the generalization error of their models than data scientists are. You can and probably should split your mineral deposit samples into train and test sets so that you can empirically estimate how accurate your Kriging model is.
Scikit-learn does have a Kriging model, which uses `fit` and `predict` methods. I find their interface more intuitive.
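The scikit-learn model in question is presumably `GaussianProcessRegressor` (Gaussian process regression is the ML name for Kriging). A minimal fit/predict sketch on toy 1-D data, standing in for "fit on samples, predict on the grid":

```python
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

# Toy "samples": a few observed locations and values.
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.sin(X).ravel()

# Toy "grid": locations where we want interpolated values.
grid = np.linspace(0.0, 3.0, 7).reshape(-1, 1)

gpr = GaussianProcessRegressor(kernel=RBF(length_scale=1.0))
gpr.fit(X, y)                                    # fit the GP to the samples
pred, std = gpr.predict(grid, return_std=True)   # interpolate on the grid
print(pred.shape)  # (7,)
```

With the default (noise-free) setting the predictions interpolate the observed values exactly, and `return_std=True` also yields the Kriging variance at each grid location.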
> Also, your argument here and in other parts of the text below is not very good. You are saying something like: "users can do messy things, so we should limit the power of our interface to only handle a subset of features".
A good API helps the user avoid errors. The power of a table query language like SQL comes from its limited scope. Here's a quote from "Database Systems: The Complete Book" (page 38) in answer to the question "Why do we need a special query language?":
> The surprising answer is that relational algebra is useful because it is less powerful than C or Java. That is, there are computations one can perform in any conventional language that one cannot perform in relational algebra. An example is: determine whether the number of tuples in a relation is even or odd. By limiting what we can say or do in our query language, we get two huge rewards — ease of programming and the ability of the compiler to produce highly optimized code.
> Perhaps geospatial scientists are less interested in testing the generalization error of their models than data scientists are. You can and probably should split your mineral deposit samples into train and test sets so that you can empirically estimate how accurate your Kriging model is.
Check our paper, which is all about generalization error in geospatial settings: https://www.frontiersin.org/articles/10.3389/fams.2021.689393/full
We are in the process of porting all these validation methods to be compatible with TableTransforms.jl pipelines. They are already implemented in GeoStats.jl as you can see here: https://juliaearth.github.io/GeoStatsDocs/stable/validation.html
Also watch our JuliaCon talk for more examples: https://www.youtube.com/watch?v=75A6zyn5pIE

The things that we can do are already much more sophisticated than what sk-learn or MLJ can do, because they are not flexible enough to handle geospatial domains, efficient (lazy) partitioning schemes, etc.