w3c/machine-learning-workshop

In-browser training

Opened this issue · 3 comments

@cynthia's Machine Learning in Web Architecture talk makes a (wild) guess:

Less than 1% of users will train in the browser context

The current in-browser efforts (e.g. those pursued by the W3C's Machine Learning for the Web Community Group) are focused on inference rather than training. This is for pragmatic reasons: the limited availability of the respective platform APIs needed to make this process efficient, and the resource restrictions of a browser architecture not optimized for such a demanding task.

For in-browser inference, a model with a total weight size somewhere in the ballpark of ~100 MB starts to be too slow on typical desktop hardware. Training requires more memory and compute, so even smaller models than that will likely be too slow to train in a browser to be useful in most use cases. @huningxin probably has a couple of pointers to model size vs. performance evaluations.
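To make the "training needs more than inference" point concrete, here is a rough back-of-envelope sketch (my own arithmetic, not from a cited evaluation, and ignoring activations and input batches): with an optimizer like Adam, each parameter carries a gradient plus two moment estimates on top of the weight itself, so the parameter-related working set alone is roughly 4× the weight size.

```javascript
// Rough memory-footprint estimate for training: besides the weights,
// training needs a gradient per parameter and, with Adam, two extra
// moment tensors per parameter, so the parameter-related working set
// is roughly 4x the weight size (activations not counted).
function trainingFootprintMB(paramCount, bytesPerParam = 4) {
  const weights = paramCount * bytesPerParam;
  const gradients = weights;     // one gradient per parameter
  const adamState = 2 * weights; // first and second moment estimates
  return (weights + gradients + adamState) / (1024 * 1024);
}

// A ~100 MB float32 model (~26M parameters) needs on the order of
// 400 MB for parameters alone when training with Adam:
const paramsFor100MB = (100 * 1024 * 1024) / 4;
console.log(trainingFootprintMB(paramsFor100MB).toFixed(0)); // "400"
```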

My questions:

  • Are non-browser JS environments, unhindered by the resource restrictions of a browser client, the only feasible short-term target for JS-based training (discussed in #62)?
  • Assuming we're headed toward a future where in-browser training eventually becomes a thing, are there obvious gaps in browser capabilities and APIs that could be bridged to smooth the path to that future?

For example, the large amounts of data needed to train a model are currently better ingested outside the browser from a native file system. The Native File System API may address the issue of data ingestion for in-browser usage. What other such API gaps would make the memory- and compute-intensive task of training more feasible in the browser context?

@irealva to comment in this issue on experiences of in-browser training with Teachable Machine, a project successfully used for in-browser training to solve real-world problems in cross-disciplinary contexts such as education and the arts. What limitations are you facing, and how have you worked around them? Size limits of the models? Different browser behavior? Your input is valuable in determining which gaps in web platform capabilities should be improved.

@nsthorat to comment on key learnings from designing TensorFlow.js for training models. My understanding is that the Node.js-based backend is primarily used for training, mainly due to the following advantages over the browser-based one:

  • native filesystem access
  • ability to bind to the TensorFlow C library, which makes use of various hardware acceleration paths with backends for CPU, GPU, TPU, etc.
  • other things I overlooked :-)

On the browser side, it seems you've managed to work around the limitations and complexities of WebGL with a layered approach. As you're aware, the WebNN API is intended to solve part of the issues, as is WebAssembly with SIMD, but I'm curious whether there are other API gaps we've overlooked that we could address with standards-based Web APIs. Specific requirements with respect to filesystem access? Better memory management (discussed in #63)?
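As a small aside on the Wasm SIMD path: its availability currently has to be probed at runtime by validating a tiny module that uses a SIMD instruction. The sketch below is modeled on the probe used by the wasm-feature-detect library; treat the exact byte sequence as illustrative.

```javascript
// Runtime probe for WebAssembly SIMD support: ask the engine to validate
// a minimal module whose single function uses a v128 operation. Modeled
// on the probe from the wasm-feature-detect library.
function wasmSimdSupported() {
  return WebAssembly.validate(new Uint8Array([
    0, 97, 115, 109, 1, 0, 0, 0, // "\0asm" magic + version
    1, 5, 1, 96, 0, 1, 123,      // type section: () -> v128
    3, 2, 1, 0,                  // function section
    10, 10, 1, 8, 0, 65, 0, 253, 15, 253, 98, 11, // body uses i8x16.splat
  ]));
}

// Returns true on engines with SIMD support, false otherwise.
console.log(wasmSimdSupported());
```

A library like tf.js can use such a probe to pick its fastest available Wasm backend at startup.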

Teachable Machine features in-browser training for all the models in the app: the image, audio, and pose models. Users add their samples to different classes and train their own custom models directly in the browser. None of the samples ever hit our servers during training.

We get around some of these limitations by restricting training to transfer learning, which makes it much faster. For example, in the case of the image model used in Teachable Machine, users only have to re-train 2 dense layers over their own data. All of the training code can be found in this repo if you want to take a look.
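For readers unfamiliar with the technique: this is not Teachable Machine's actual code (which uses tf.js), but a dependency-free sketch of the transfer-learning idea, with a frozen "base model" standing in for MobileNet and a single trainable head instead of the two dense layers. All names are illustrative.

```javascript
// Transfer-learning sketch: a frozen base model turns each sample into
// an embedding; only a small classification head is trained on top.

// Stand-in for a frozen base model: maps a raw sample to an embedding.
const embed = (sample) => sample.map((x) => Math.tanh(x));

// Train a single dense layer with a sigmoid output (binary head for
// brevity) using plain per-sample gradient descent on cross-entropy.
function trainHead(samples, labels, { lr = 0.5, epochs = 200 } = {}) {
  const dim = embed(samples[0]).length;
  const w = new Array(dim).fill(0);
  let b = 0;
  for (let e = 0; e < epochs; e++) {
    for (let i = 0; i < samples.length; i++) {
      const z = embed(samples[i]);
      const logit = z.reduce((s, zi, j) => s + zi * w[j], b);
      const p = 1 / (1 + Math.exp(-logit));
      const err = p - labels[i]; // dLoss/dLogit for cross-entropy
      for (let j = 0; j < dim; j++) w[j] -= lr * err * z[j];
      b -= lr * err;
    }
  }
  // Return a classifier; only the head's weights were ever updated.
  return (sample) => {
    const z = embed(sample);
    const logit = z.reduce((s, zi, j) => s + zi * w[j], b);
    return 1 / (1 + Math.exp(-logit)) > 0.5 ? 1 : 0;
  };
}

// Two tiny classes of user "samples", as in the Teachable Machine flow.
const samples = [[2, 2], [1.5, 2.5], [-2, -2], [-1, -3]];
const labels = [1, 1, 0, 0];
const predict = trainHead(samples, labels);
console.log(predict([2.2, 1.8]), predict([-2.5, -1.5])); // 1 0
```

Because the expensive base model never updates, each training step touches only a handful of parameters, which is why this stays fast enough for a browser.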

A few more technical notes:

  • We lazy-load the base models for transfer learning. The image one is a MobileNet V2 model at ~1.6 MB. The audio model is Speech Commands at ~5.9 MB. The pose model is MobileNet V1 at ~4.7 MB.
  • We decided not to focus on a mobile site for Teachable Machine (you can load it on mobile, but we do not guarantee that it will work). This was less because of hardware limitations on mobile (we tried training with a newer Pixel phone, and despite the training being slower, one can still train a model) and more because the interface was very tricky to adapt to mobile and we felt the desktop experience was far superior. With more time, we would have worked on a different interface optimized for mobile.
  • Given we rely on tf.js, we are bound by all the technical limitations of that library. For example, we have to warn users not to switch browser tabs while training, or the training will stop. For an app like Teachable Machine, these are small but very important UX challenges.
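On the tab-switching point: browsers throttle timers and `requestAnimationFrame` in background tabs, which can stall a main-thread training loop. A common mitigation is to warn the user when the tab goes hidden mid-training; the sketch below shows the general shape (the `trainingActive` flag and message are my own, not Teachable Machine's code).

```javascript
// Warn the user if the tab is hidden while a training run is active,
// since background-tab throttling can effectively pause training.
let trainingActive = false; // an app would flip this around its train loop

// Pure decision function, kept separate so it is easy to test.
function hiddenTabMessage(isTraining, isHidden) {
  return isTraining && isHidden
    ? 'Training paused: keep this tab in the foreground to continue.'
    : null;
}

// Wire-up only makes sense in a browser context.
if (typeof document !== 'undefined') {
  document.addEventListener('visibilitychange', () => {
    const msg = hiddenTabMessage(trainingActive, document.hidden);
    if (msg) console.warn(msg);
  });
}
```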

Notes on usage:

  • We don't restrict the number of samples that users can add to their models. Most users add hundreds of samples to each class; a few try adding thousands of samples or hundreds of classes, and for those users there are occasional reports of the app crashing. This is probably due to browser memory limitations, as we load all samples into a Float32Array before training.
  • Because all of the training happens in a user's browser, their experience depends heavily on their own hardware. It is a challenge to give users an estimate of training time, sample size limitations, etc., so we've opted not to and instead simply provide a counter for the number of epochs trained.
  • We've found that some users really don't mind long training times (30+ minutes) if they are highly motivated to create their model. Given there are so few alternatives for non-technical users who want to play with ML in an easy way, these users will put up with very long training times.
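A quick back-of-envelope check (my own numbers, assuming a typical 224×224 RGB image-model input, not Teachable Machine's actual accounting) shows why thousands of samples held in a Float32Array can exhaust tab memory:

```javascript
// Memory grows linearly with sample count when every sample is decoded
// into a Float32Array (4 bytes per value). 224x224x3 is assumed here as
// a typical image-model input size.
function samplesMemoryMB(numSamples, width = 224, height = 224, channels = 3) {
  const bytesPerSample = width * height * channels * 4; // Float32 = 4 bytes
  return (numSamples * bytesPerSample) / (1024 * 1024);
}

console.log(samplesMemoryMB(300).toFixed(0));  // hundreds of samples: "172" MB
console.log(samplesMemoryMB(3000).toFixed(0)); // thousands: "1723" MB
```

At a few thousand samples the raw tensors alone approach 2 GB, well past what many tabs can allocate.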

@hapticdata might have some more thoughts.

As discussed yesterday here are a few additional considerations:

  • Starting training always pauses the UI thread. We make sure to show proper messaging before that tick so that people know during that moment that the app is in fact working.
  • Users can add thousands of images and potentially crash their tab if their device does not have enough memory to process them at training time.
  • The performance differences between browsers with tf.js, in particular Safari, are difficult to express to users, as is estimating how long training may take.
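The "show messaging before the blocking tick" workaround can be sketched as follows (the helper names are illustrative, not Teachable Machine's API): yield to the event loop after setting the status so the browser can paint it before the main thread gets tied up by the first training step.

```javascript
// Let a status message render before main-thread-blocking work begins.
const nextTick = () => new Promise((resolve) => setTimeout(resolve, 0));

async function startTraining(setStatus, runTrainingStep) {
  setStatus('Training\u2026 this may freeze the page briefly.');
  await nextTick();         // give the browser a chance to paint the message
  return runTrainingStep(); // heavy, UI-thread-blocking work starts here
}
```

In a browser, two chained `requestAnimationFrame` calls are an even stronger guarantee that the frame was painted; `setTimeout(0)` keeps the sketch runnable anywhere.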