piEsposito/blitz-bayesian-deep-learning

Error in BayesianRNN class

LinglongQian opened this issue · 1 comment

In the code of BayesianRNN, there are two self.bias = None assignments in __init__; I suppose the first one should be self.bias_mu = None.
This class also lacks a definition of self.bias_prior_dist, which is used in def sharpen_posterior(self, loss, input_shape). A sketch of the fix I have in mind follows the quoted class below.

class BayesianRNN(BayesianModule):
    """
    implements base class for B-RNN to enable posterior sharpening
    """
    def __init__(self,
                 sharpen=False):
        super().__init__()
        
        self.weight_ih_mu = None
        self.weight_hh_mu = None
        self.bias = None
        
        self.weight_ih_sampler = None
        self.weight_hh_sampler = None
        self.bias_sampler = None

        self.weight_ih = None
        self.weight_hh = None
        self.bias = None
        
        self.sharpen = sharpen
        
        self.weight_ih_eta = None
        self.weight_hh_eta = None
        self.bias_eta = None
        self.ff_parameters = None
        self.loss_to_sharpen = None
        
    
    def sample_weights(self):
        pass
    
    def init_sharpen_parameters(self):
        if self.sharpen:
            self.weight_ih_eta = nn.Parameter(torch.Tensor(self.weight_ih_mu.size()))
            self.weight_hh_eta = nn.Parameter(torch.Tensor(self.weight_hh_mu.size()))
            self.bias_eta = nn.Parameter(torch.Tensor(self.bias_mu.size()))
            
            self.ff_parameters = []

            self.init_eta()
    
    def init_eta(self):
        stdv = 1.0 / math.sqrt(self.weight_hh_eta.shape[0]) #correspond to hidden_units parameter
        self.weight_ih_eta.data.uniform_(-stdv, stdv)
        self.weight_hh_eta.data.uniform_(-stdv, stdv)
        self.bias_eta.data.uniform_(-stdv, stdv)
    
    def set_loss_to_sharpen(self, loss):
        self.loss_to_sharpen = loss
    
    def sharpen_posterior(self, loss, input_shape):
        """
        sharpens the posterior distribution by using the algorithm proposed in
        @article{DBLP:journals/corr/FortunatoBV17,
          author    = {Meire Fortunato and
                       Charles Blundell and
                       Oriol Vinyals},
          title     = {Bayesian Recurrent Neural Networks},
          journal   = {CoRR},
          volume    = {abs/1704.02798},
          year      = {2017},
          url       = {http://arxiv.org/abs/1704.02798},
          archivePrefix = {arXiv},
          eprint    = {1704.02798},
          timestamp = {Mon, 13 Aug 2018 16:48:21 +0200},
          biburl    = {https://dblp.org/rec/journals/corr/FortunatoBV17.bib},
          bibsource = {dblp computer science bibliography, https://dblp.org}
        }
        """
        bs, seq_len, in_size = input_shape
        gradients = torch.autograd.grad(outputs=loss,
                                        inputs=self.ff_parameters,
                                        grad_outputs=torch.ones(loss.size()).to(loss.device),
                                        create_graph=True,
                                        retain_graph=True,
                                        only_inputs=True)
        
        grad_weight_ih, grad_weight_hh, grad_bias = gradients
        
        #to generate sigmas on the weight sampler
        _ = self.sample_weights()
        
        weight_ih_sharpened = self.weight_ih_mu - self.weight_ih_eta * grad_weight_ih + self.weight_ih_sampler.sigma
        weight_hh_sharpened = self.weight_hh_mu - self.weight_hh_eta * grad_weight_hh + self.weight_hh_sampler.sigma
        bias_sharpened = self.bias_mu - self.bias_eta * grad_bias + self.bias_sampler.sigma
        
        if self.bias is not None:
            b_log_posterior = self.bias_sampler.log_posterior(w=bias_sharpened)
            b_log_prior_ = self.bias_prior_dist.log_prior(bias_sharpened)
            
        else:
            b_log_posterior = b_log_prior = 0
        
        
        self.log_variational_posterior += (self.weight_ih_sampler.log_posterior(w=weight_ih_sharpened) + b_log_posterior + self.weight_hh_sampler.log_posterior(w=weight_hh_sharpened)) / seq_len
        
        self.log_prior += self.weight_ih_prior_dist.log_prior(weight_ih_sharpened) + b_log_prior + self.weight_hh_prior_dist.log_prior(weight_hh_sharpened) / seq_len
        
        return weight_ih_sharpened, weight_hh_sharpened, bias_sharpened
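
For reference, this is roughly what I would expect the attribute declarations in __init__ to look like. It is only a sketch of my suggestion: the names simply mirror the existing weight_ih_* / weight_hh_* attributes, and the concrete recurrent layers would presumably still create the actual samplers and prior distributions.

def __init__(self, sharpen=False):
    super().__init__()

    self.weight_ih_mu = None
    self.weight_hh_mu = None
    self.bias_mu = None          # was: self.bias = None (duplicate of the attribute below)

    self.weight_ih_sampler = None
    self.weight_hh_sampler = None
    self.bias_sampler = None

    self.weight_ih = None
    self.weight_hh = None
    self.bias = None

    # declared here so sharpen_posterior can rely on them;
    # subclasses would still assign the real prior distributions
    self.weight_ih_prior_dist = None
    self.weight_hh_prior_dist = None
    self.bias_prior_dist = None  # currently never defined, but used in sharpen_posterior

    self.sharpen = sharpen
    # ... remaining attributes (weight_ih_eta, weight_hh_eta, bias_eta,
    # ff_parameters, loss_to_sharpen) unchanged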

Are parentheses also missing here? As written, only the weight_hh_prior_dist term is divided by seq_len, unlike the log_variational_posterior line above:
self.log_prior += self.weight_ih_prior_dist.log_prior(weight_ih_sharpened) + b_log_prior + self.weight_hh_prior_dist.log_prior(weight_hh_sharpened) / seq_len
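
If the whole sum is meant to be scaled by seq_len, as in the log_variational_posterior update, I would expect something like:

self.log_prior += (self.weight_ih_prior_dist.log_prior(weight_ih_sharpened) + b_log_prior + self.weight_hh_prior_dist.log_prior(weight_hh_sharpened)) / seq_len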