https://ccnets.org/tutorials?category=CelebA
Causal Cooperative Nets (CCNets) is a new ML framework composed of three neural network models: the explainer, the reasoner, and the producer. It is trained to discover causal relationships in statistical data and to explain black-box ML models.
Supervised and generative learning train one or two models on the data, whereas CCNets train three models (explainer, reasoner, producer) cooperatively and simultaneously. In supervised learning, a model normally receives an observation as input and predicts its label, learning an association between the observation space and the label space. CCNets, in contrast, take both an observation and its label as input and learn a causal relationship between the observation space and the label space.
- Explainer: The explainer model learns to explain the data with respect to its labels. It receives observed data as input and outputs a latent vector (the causal explanation vector) in the explanatory space. Any neural network used as a classifier or regressor can serve as the explainer in CCNets.
- Reasoner: The reasoner model learns to infer a label from an explanation. It receives observed data and a causal explanation vector as input and outputs an inferred label in the label space. Any neural network used as a classifier or regressor can serve as the reasoner in CCNets.
- Producer: The producer model learns to generate data from an explanation. It receives a label and a causal explanation vector as input and outputs generated data.
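The data flow between the three models can be sketched in PyTorch as below. This is a minimal sketch, not the CCNets implementation: the small MLP bodies, the flattened image input, and the sizes (obs_dim, dim_explanation, dim_label, batch) are hypothetical choices made only to show which tensors each model takes and returns.

```python
import torch
import torch.nn as nn


class Explainer(nn.Module):
    """Observation -> causal explanation vector (hypothetical MLP body)."""

    def __init__(self, obs_dim: int, dim_explanation: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, 128), nn.ReLU(), nn.Linear(128, dim_explanation))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Reasoner(nn.Module):
    """(Observation, explanation) -> inferred label (hypothetical MLP body)."""

    def __init__(self, obs_dim: int, dim_explanation: int, dim_label: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim + dim_explanation, 128), nn.ReLU(), nn.Linear(128, dim_label))

    def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        # Concatenate observation and explanation along the feature axis.
        return self.net(torch.cat([x, e], dim=1))


class Producer(nn.Module):
    """(Label, explanation) -> generated observation (hypothetical MLP body)."""

    def __init__(self, obs_dim: int, dim_explanation: int, dim_label: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim_label + dim_explanation, 128), nn.ReLU(), nn.Linear(128, obs_dim))

    def forward(self, y: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
        # Concatenate label and explanation along the feature axis.
        return self.net(torch.cat([y, e], dim=1))


# Hypothetical sizes: flattened 3x32x32 images and 40 binary attribute labels
# (CelebA annotates 40 attributes per image).
obs_dim, dim_explanation, dim_label, batch = 3 * 32 * 32, 32, 40, 8
explainer = Explainer(obs_dim, dim_explanation)
reasoner = Reasoner(obs_dim, dim_explanation, dim_label)
producer = Producer(obs_dim, dim_explanation, dim_label)

x = torch.randn(batch, obs_dim)                      # observed data
y = torch.randint(0, 2, (batch, dim_label)).float()  # labels, also an input in CCNets

e = explainer(x)              # causal explanation vector
y_inferred = reasoner(x, e)   # inferred label in the label space
x_generated = producer(y, e)  # generated data in the observation space
```

The sketch covers only the forward pass; the cooperative losses that train the three models together are not described in this tutorial and are left out.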
git clone https://github.com/junho-ccnets/causal-learning
cd causal-learning
pip3 install -r requirements.txt
If you just want to test training, you first need the dataset. torchvision can load the CelebA training split and download it automatically (download=True):
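import torchvision.datasets as dset
import torchvision.transforms as transforms

dataroot = "<data_path>"  # root directory for the dataset
n_img_sz = 64  # image width & height (64 is only an example value)

# Load the CelebA training split: resize, center-crop to n_img_sz x n_img_sz,
# convert to a tensor, and normalize each RGB channel to [-1, 1].
trainset = dset.CelebA(root=dataroot, split="train", transform=transforms.Compose([
    transforms.Resize(n_img_sz),
    transforms.CenterCrop(n_img_sz),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]), download=True)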
The training parameters are:
- dataroot: Root directory for the dataset
- workers: Number of worker processes for the data loader
- batch_size: Batch size during training
- n_img_ch: Number of image channels
- n_img_sz: Image width and height
- dim_explanation: Number of dimensions of the causal explanation vector in the explanatory space
- dim_label: Dimension of the labels
- lr: Learning rate for training on CelebA images
- num_epochs: Number of training epochs
- step_size: Step size of the learning-rate scheduler (how often the learning rate is decayed by gamma)
- beta1: Coefficient used for computing the running average of the gradient
- gamma: Decay rate of the learning-rate scheduler
- manualSeed: Random seed for reproducibility
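Several of these parameters map onto standard PyTorch objects. The sketch below assumes the common pairing of lr and beta1 with torch.optim.Adam, of step_size and gamma with torch.optim.lr_scheduler.StepLR, and of manualSeed with torch.manual_seed; the repository may wire them differently. All values and the placeholder model are hypothetical.

```python
import random
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class Config:
    # Hypothetical values; only the names come from the parameter list.
    workers: int = 2
    batch_size: int = 64
    n_img_ch: int = 3
    n_img_sz: int = 64
    dim_explanation: int = 32
    dim_label: int = 40
    lr: float = 2e-4
    num_epochs: int = 10
    step_size: int = 5
    beta1: float = 0.5
    gamma: float = 0.5
    manualSeed: int = 999


cfg = Config()

# Seed Python's and PyTorch's random number generators for reproducibility.
random.seed(cfg.manualSeed)
torch.manual_seed(cfg.manualSeed)

model = nn.Linear(cfg.dim_explanation, cfg.dim_label)  # hypothetical placeholder model

# Assumption: Adam with betas=(beta1, 0.999); 0.999 is PyTorch's default second beta.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
# Assumption: StepLR multiplies the learning rate by gamma every step_size scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.step_size, gamma=cfg.gamma)

lrs = []
for epoch in range(cfg.num_epochs):
    # ... one epoch of training would go here ...
    optimizer.step()
    scheduler.step()  # stepped once per epoch in this sketch
    lrs.append(optimizer.param_groups[0]["lr"])
```

With these example values and one scheduler step per epoch, the learning rate drops from 2e-4 to 1e-4 after epoch 5 and to 5e-5 after epoch 10.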
© 2022 CCNets, Inc. All rights reserved.