idstcv/ZenNAS

How to do `entropy_forward` for CSP network?

1chimaruGin opened this issue · 2 comments

My block looks like this:

BottleneckCSP(
  (cv1): Conv(
    (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): Mish()
  )
  (cv2): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (cv3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (cv4): Conv(
    (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): Mish()
  )
  (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): Mish()
  (m): Sequential(
    (0): Bottleneck(
      (cv1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
      (cv2): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
    )
    (1): Bottleneck(
      (cv1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
      (cv2): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): Mish()
      )
    )
  )
)
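For anyone who wants to reproduce the block, here is a minimal PyTorch sketch reconstructed from the printout above. The layer shapes and attribute names follow the printout. The residual add inside `Bottleneck`, the `shortcut` flag, and the constructor arguments are assumptions, because the printout does not show them. `nn.Mish` requires PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn


class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish, matching the printed `Conv` wrapper."""

    def __init__(self, c_in, c_out, k=1, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k, s, padding=k // 2, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.Mish()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    """1x1 Conv then 3x3 Conv; the residual add is an assumption (not in the printout)."""

    def __init__(self, c, shortcut=True):
        super().__init__()
        self.cv1 = Conv(c, c, 1)
        self.cv2 = Conv(c, c, 3)
        self.add = shortcut

    def forward(self, x):
        y = self.cv2(self.cv1(x))
        return x + y if self.add else y


class BottleneckCSP(nn.Module):
    def __init__(self, c_in=128, c_out=128, n=2, shortcut=True):
        super().__init__()
        c_ = c_out // 2  # hidden width: 64 in the printout
        self.cv1 = Conv(c_in, c_, 1)
        self.cv2 = nn.Conv2d(c_in, c_, 1, bias=False)   # plain conv, no BN/act
        self.cv3 = nn.Conv2d(c_, c_, 1, bias=False)     # plain conv, no BN/act
        self.cv4 = Conv(2 * c_, c_out, 1)
        self.bn = nn.BatchNorm2d(2 * c_)                # BN applied after the concat
        self.act = nn.Mish()
        self.m = nn.Sequential(*[Bottleneck(c_, shortcut) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))  # main branch through the bottlenecks
        y2 = self.cv2(x)                    # cross-stage partial branch
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
```

Printing `BottleneckCSP()` should give the same layer list as above; the spatial size and channel count of the input (128) are preserved.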

And its `forward(self, x)`:

def forward(self, x):
    d = self.m(self.cv1(x))
    y1 = self.cv3(d)
    y2 = self.cv2(x)
    return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
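Since an `entropy_forward` has to follow the same data flow as `forward`, one way to handle a CSP block is to mirror both branches and the concatenation, then compute the statistic on the merged output. The sketch below is hypothetical: `CSPEntropyMixin`, `gaussian_entropy`, and `TinyCSP` are illustrative names, and using the differential entropy of a Gaussian with the feature map's variance, 0.5 * log(2*pi*e*var), is an assumed stand-in rather than MAE-DET's actual score.

```python
import math

import torch
import torch.nn as nn


def gaussian_entropy(feat: torch.Tensor) -> torch.Tensor:
    """Differential entropy of a Gaussian whose variance equals the feature map's (assumed proxy)."""
    var = feat.float().var()
    return 0.5 * torch.log(2 * math.pi * math.e * var)


class CSPEntropyMixin:
    """Hypothetical mixin; expects cv1, cv2, cv3, cv4, m, bn, act as in the printed BottleneckCSP."""

    def entropy_forward(self, x):
        # Same two-branch graph as forward(): main branch, CSP branch, concat, BN, act, cv4.
        d = self.m(self.cv1(x))
        y1 = self.cv3(d)
        y2 = self.cv2(x)
        out = self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
        # Return the output for the next block plus this block's entropy estimate.
        return out, gaussian_entropy(out)


class TinyCSP(CSPEntropyMixin, nn.Module):
    """Small stand-in with the same attribute layout, only for trying the mixin."""

    def __init__(self, c=8):
        super().__init__()
        h = c // 2
        self.cv1 = nn.Conv2d(c, h, 1, bias=False)
        self.cv2 = nn.Conv2d(c, h, 1, bias=False)
        self.cv3 = nn.Conv2d(h, h, 1, bias=False)
        self.cv4 = nn.Conv2d(c, c, 1, bias=False)
        self.bn = nn.BatchNorm2d(c)
        self.act = nn.Mish()
        self.m = nn.Identity()
```

Computing the estimate on the concatenated output keeps it tied to what the next block actually receives; whether per-branch terms are wanted instead is not settled in this thread.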

I think you can simply call this function:

def compute_nas_score(gpu, model, mixup_gamma, resolution, batch_size, repeat, fp16=False):

The Zen-score will be saved in `info['avg_nas_score']`.
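For reference, here is a simplified, self-contained sketch of what a Zen-score style computation involves: Gaussian-initialized weights, the mean L1 distance between the outputs for a random input and a slightly perturbed copy, plus the sum of log standard deviations seen by the BN layers. It is not the ZenNAS implementation; `zen_score_sketch` and its defaults are hypothetical, and it assumes the model returns a 4-D feature map (for example, the output before global average pooling).

```python
import torch
import torch.nn as nn


def zen_score_sketch(model, resolution=32, batch_size=8, mixup_gamma=1e-2, in_ch=3):
    """Hypothetical simplified Zen-score: log(mean L1 output gap) + sum of log BN std devs."""
    # Gaussian weight init, unit BN scale, zero biases.
    for mod in model.modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            nn.init.normal_(mod.weight)
            if mod.bias is not None:
                nn.init.zeros_(mod.bias)
        elif isinstance(mod, nn.BatchNorm2d) and mod.affine:
            nn.init.ones_(mod.weight)
            nn.init.zeros_(mod.bias)

    bn_stds = []

    def record_std(mod, inputs, output):
        # sigma_i: sqrt of the mean per-channel batch variance entering this BN layer.
        var = inputs[0].var(dim=(0, 2, 3), unbiased=False).mean()
        bn_stds.append(torch.sqrt(var))

    handles = [m.register_forward_hook(record_std)
               for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    model.train()  # BN normalizes with batch statistics
    with torch.no_grad():
        x = torch.randn(batch_size, in_ch, resolution, resolution)
        out = model(x)
        for h in handles:  # only the clean pass contributes sigma_i
            h.remove()
        out_mix = model(x + mixup_gamma * torch.randn_like(x))

    gap = torch.sum(torch.abs(out - out_mix), dim=[1, 2, 3]).mean()
    log_bn = torch.stack(bn_stds).log().sum() if bn_stds else torch.tensor(0.0)
    return float(torch.log(gap) + log_bn)
```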

Sorry @MingLin-home, I opened this issue in the wrong repo. I meant to ask about Lightning NAS (MAE-DET).
By the way, thanks for the kind response.