griegler/octnet

Accuracy when running code in train_mn10_r64.lua

YiruS opened this issue · 1 comment

YiruS commented

Hi,

I recently installed Torch and ran the example script train_mn10_r64.lua with everything left at the defaults: 20 epochs, an initial learning rate of 0.001, a batch size of 32 and the Adam optimizer.
However, the test accuracy is only 0.457, as shown below:
[INFO] loading data took 0.005721[s] - n_batches 32
[INFO] net/crtrn fwd took 0.484423[s]
[INFO] test batch 29/29
[INFO] loading data took 0.004988[s] - n_batches 12
[INFO] net/crtrn fwd took 0.205677[s]
test_epoch=20, avg_f=0.457447

How can I reach the accuracy reported in the OctNet paper?

Thanks!

The printed avg_f is not the model's accuracy but the average loss (cross-entropy) over the test batches, so lower is better.
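To see why a loss of about 0.46 says little about accuracy on its own, here is a small pure-Lua sketch with made-up probabilities for three samples. It assumes the network outputs log-probabilities (as nn.LogSoftMax would) and computes the mean negative log-likelihood (cross-entropy) next to the argmax accuracy; the function name nll_and_accuracy is hypothetical.

```lua
-- Toy illustration with hypothetical numbers: mean cross-entropy vs. accuracy.
-- Assumes each row already holds log-probabilities, as produced by nn.LogSoftMax.

local function nll_and_accuracy(log_probs, targets)
  local loss, correct = 0, 0
  for i, row in ipairs(log_probs) do
    -- cross-entropy contribution: minus the log-probability of the true class
    loss = loss - row[targets[i]]
    -- predicted class = index of the largest log-probability (argmax)
    local best = 1
    for c = 2, #row do
      if row[c] > row[best] then best = c end
    end
    if best == targets[i] then correct = correct + 1 end
  end
  return loss / #log_probs, correct / #log_probs
end

local log_probs = {
  { math.log(0.7), math.log(0.2), math.log(0.1) },
  { math.log(0.1), math.log(0.6), math.log(0.3) },
  { math.log(0.3), math.log(0.3), math.log(0.4) },
}
local targets = { 1, 2, 3 }

local loss, acc = nll_and_accuracy(log_probs, targets)
print(string.format('avg loss=%.3f, accuracy=%.3f', loss, acc))
-- prints: avg loss=0.595, accuracy=1.000
```

Every sample is classified correctly here, yet the mean loss is still about 0.59, because cross-entropy also penalizes predictions that are right but not confident.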
The following version of common.test_epoch also computes the classification accuracy:

-- Replacement for common.test_epoch that reports the accuracy as well as the loss.
function common.test_epoch(opt, data_loader)
  local net = opt.net or error('no net in test_epoch')
  local criterion = opt.criterion or error('no criterion in test_epoch')
  local n_batches = data_loader:n_batches()

  -- switch layers such as dropout and batch normalization to inference mode
  net:evaluate()

  local avg_f = 0      -- sum (later mean) of the per-batch losses
  local accuracy = 0   -- number (later fraction) of correctly classified samples
  local n_samples = 0  -- total number of test samples seen
  for batch_idx = 1, n_batches do
    print(string.format('[INFO] test batch %d/%d', batch_idx, n_batches))

    local timer = torch.Timer()
    local input, target = data_loader:getBatch()
    -- note: the value logged as "n_batches" is the size of the current batch
    print(string.format('[INFO] loading data took %f[s] - n_batches %d', timer:time().real, target:size(1)))

    timer:reset()
    local output = net:forward(input)
    -- keep only as many output rows as there are targets (the final batch holds fewer samples)
    output = output[{{1, target:size(1)}, {}}]
    local f = criterion:forward(output, target)
    print(string.format('[INFO] net/crtrn fwd took %f[s]', timer:time().real))
    avg_f = avg_f + f

    -- predicted class = index of the highest score in each row
    local _, indices = torch.max(output, 2)
    for bidx = 1, target:size(1) do
      if indices[bidx][1] == target[bidx] then
        accuracy = accuracy + 1
      end
      n_samples = n_samples + 1
    end
  end
  avg_f = avg_f / n_batches
  accuracy = accuracy / n_samples

  print(string.format('test_epoch=%d, avg_f=%f, accuracy=%f', opt.epoch, avg_f, accuracy))
  return avg_f, accuracy
end
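The two statistics are normalized differently: accuracy divides the number of correct predictions by n_samples, whereas avg_f divides the sum of per-batch losses by n_batches. If the criterion returns the mean loss of each batch (the default for nn criteria with sizeAverage enabled), the final batch of 12 samples counts as much as a full batch of 32. The pure-Lua sketch below compares both normalizations; the loss values are hypothetical, the helper average_losses is made up for illustration, and it assumes every batch except the last holds 32 samples.

```lua
-- Compare two ways of averaging per-batch mean losses.
-- Hypothetical losses; batch sizes assume 28 full batches of 32, then one batch of 12.

local function average_losses(batch_losses, batch_sizes)
  local sum_of_means = 0  -- what avg_f accumulates: one term per batch
  local weighted_sum = 0  -- each batch mean weighted by its number of samples
  local n_samples = 0
  for i = 1, #batch_losses do
    sum_of_means = sum_of_means + batch_losses[i]
    weighted_sum = weighted_sum + batch_losses[i] * batch_sizes[i]
    n_samples = n_samples + batch_sizes[i]
  end
  return sum_of_means / #batch_losses, weighted_sum / n_samples, n_samples
end

local batch_losses, batch_sizes = {}, {}
for i = 1, 28 do
  batch_losses[i], batch_sizes[i] = 0.4, 32
end
batch_losses[29], batch_sizes[29] = 2.0, 12

local per_batch, per_sample, n = average_losses(batch_losses, batch_sizes)
print(string.format('per-batch mean=%.4f, per-sample mean=%.4f, samples=%d',
                    per_batch, per_sample, n))
```

With these made-up numbers the per-batch mean is about 0.455 and the per-sample mean about 0.421; the gap only matters when the last, smaller batch behaves very differently from the others.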