Kyun Kyu Kim#1,
Min Kim#2,
Sungho Jo*2,
Seung Hwan Ko*3,
Zhenan Bao*1
1Stanford, CA, USA, 2Korea Advanced Institute of Science and Technology (KAIST), Daejeon, Korea, 3Seoul National University, Seoul, Korea
#denotes equal contribution
in Nature Electronics
This repo is written in Python 3.9. Any Python version > 3.7 will be compatible with our code.
This repo is tested on Windows with CUDA 11. For other environments, please install PyTorch by following the instructions on the official PyTorch website: https://pytorch.org/get-started/locally/ For the same environment, you can install PyTorch with the following command:
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
Python 3 dependencies:
- PyTorch 1.12
- attrs
- numpy
- PyQt5
- scikit-learn
We provide a conda environment setup file with all the dependencies required to run our code. You can create a conda environment named tdc by running the command below:
conda env create -f environment.yml
Our training is divided into two separate parts: 1. TD-C Learning, 2. Rapid Adaptation. We provide code and experiment environments for adopting our learning method, including data parsing, training code, and a basic UI for collecting few-shot demonstrations and making real-time inferences.
TD-C learning is an unsupervised learning method that uses jittering signal augmentation and time-dependent contrastive learning to learn sensor representations from unlabeled random motion data. Here we show the data format our code expects and how to run it with sample unlabeled data.
To run the code, first prepare byte-encoded pickle files containing sensor signals in a dictionary with the key 'sensor' whose value is the sequential sensor signals: {'sensor': array(s1, s2, ....)}. Our code reads and parses all pickle files in ./data/train_data that follow this dictionary format.
We found that the best-performing window_size and data embedding size depend on factors such as the total amount of data and the data collection frequency. You can change hyperparameter settings by modifying the values in params.py.
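As a rough illustration of jittering augmentation, the sketch below adds zero-mean Gaussian noise to a window of sensor values to produce two slightly different views of the same window. This is a generic form of jittering, not necessarily the one implemented in this repo; the function name `jitter`, the `sigma` value, and the window shape are assumptions made for the example.

```python
import numpy as np


def jitter(window, sigma=0.05, rng=None):
    """Return a copy of `window` with zero-mean Gaussian noise added.

    `sigma` is a hypothetical noise scale; it is not taken from params.py.
    """
    rng = np.random.default_rng() if rng is None else rng
    return window + rng.normal(0.0, sigma, size=window.shape)


# Two augmented views of the same (hypothetical) window of
# 32 time steps and 4 sensor channels.
rng = np.random.default_rng(0)
window = np.zeros((32, 4))
view_a = jitter(window, rng=rng)
view_b = jitter(window, rng=rng)
```

In contrastive learning setups, views like `view_a` and `view_b` are commonly treated as a positive pair, since both come from the same underlying signal.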
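The dictionary format described above can be produced with a few lines of Python. The following is a minimal sketch: the helper names, the file name, and the array shape (1000 time steps, 4 channels) are hypothetical, and a temporary directory stands in for ./data/train_data.

```python
import os
import pickle
import tempfile

import numpy as np


def save_sensor_pickle(signals, path):
    """Write sensor signals to `path` as {'sensor': array}."""
    with open(path, "wb") as f:
        pickle.dump({"sensor": np.asarray(signals)}, f)


def load_sensor_pickle(path):
    """Read a pickle file in the same format and return the signal array."""
    with open(path, "rb") as f:
        record = pickle.load(f)
    return record["sensor"]


# Hypothetical unlabeled recording: 1000 time steps, 4 sensor channels.
rng = np.random.default_rng(0)
fake_signals = rng.normal(size=(1000, 4))

# A temporary directory is used here; in the repo the files
# would be placed in ./data/train_data instead.
out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, "sample_000.pkl")
save_sensor_pickle(fake_signals, path)
loaded = load_sensor_pickle(path)
```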
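To make the role of window_size concrete, the sketch below slices a sequential signal into overlapping windows. It only illustrates the idea; the function `make_windows` and the `stride` parameter are assumptions and may not match how the repo's data parser builds its windows.

```python
import numpy as np


def make_windows(signals, window_size, stride=1):
    """Split a (time, channels) array into windows of length `window_size`."""
    signals = np.asarray(signals)
    n = signals.shape[0]
    if n < window_size:
        # Not enough samples for even one window.
        return np.empty((0, window_size) + signals.shape[1:], dtype=signals.dtype)
    starts = range(0, n - window_size + 1, stride)
    return np.stack([signals[s:s + window_size] for s in starts])


# Hypothetical signal: 10 time steps, 2 channels.
signals = np.arange(20).reshape(10, 2)
windows = make_windows(signals, window_size=4, stride=2)
```

A larger window_size gives each sample more temporal context but yields fewer windows from the same recording, which is one reason the best value depends on how much data is available.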
Run
python tdc_train.py
To allow the pretrained model to adapt to various tasks, we apply few-shot transfer learning and a metric-based inference mechanism for real-time inference. Here we provide a basic UI system implemented with PyQt5 that allows users to collect few-shot demonstrations and make real-time inferences.
We provide basic UI code in the ui directory.
The UI contains two buttons (1. Collect Start, 2. Start Prediction) and two widgets (1. a status widget showing the current prediction, 2. a sensor widget showing the current sensor values).
The system starts recording few-shot labeled data from demonstrations when the user presses the "Collect Start" button. After providing all required demonstrations, make sure to press the "Start Prediction" button so that the system starts transfer learning the model.
In the transfer_learning_base.py file, we provide transfer learning, data embedding, and metric-based inference functions.
The system performs transfer learning with the provided few-shot demonstrations. The number of transfer epochs can be modified by changing the transfer_epoch variable in params.py.
After running a few transfer epochs, the system encodes the few-shot labeled data with the transferred model to generate demo_embeddings. These embeddings are then used as references for Maximum Inner Product Search (MIPS). Given a window of sensor values, the model generates its embedding and a phase variable. If the phase variable exceeds a predefined threshold, the system performs MIPS and the corresponding prediction appears on the status widget. Otherwise, the system regards the current state as a resting state.
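The inference rule described above can be sketched with NumPy as follows. This is an illustration only and not the code in transfer_learning_base.py: the function `predict`, the `REST_LABEL` value, the threshold of 0.5, and the example embeddings and labels are all hypothetical.

```python
import numpy as np

REST_LABEL = "rest"  # hypothetical name for the resting state


def predict(query_embedding, phase, demo_embeddings, demo_labels, threshold=0.5):
    """Return the label of the demo embedding with the largest inner product,
    or REST_LABEL if the phase variable does not exceed the threshold."""
    if phase <= threshold:
        return REST_LABEL
    # Inner product between the query and every reference embedding.
    scores = demo_embeddings @ query_embedding
    return demo_labels[int(np.argmax(scores))]


# Hypothetical reference embeddings produced from few-shot demonstrations.
demo_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
demo_labels = ["fist", "open", "pinch"]
query = np.array([0.1, 0.9])

active = predict(query, phase=0.9, demo_embeddings=demo_embeddings, demo_labels=demo_labels)
resting = predict(query, phase=0.2, demo_embeddings=demo_embeddings, demo_labels=demo_labels)
```

With only a handful of demonstrations, an exhaustive inner-product scan like this is cheap enough to run on every incoming window.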