loss4tf


Primary language: Python. License: MIT.

Loss functions in TensorFlow

Various loss functions in TensorFlow. Some of them are already implemented in official TensorFlow, but I prefer to pack each one into a Layer class.

Supported Loss Functions

Recognition

  • CTC Loss
  • Focal CTC Loss
  • CTC Center Loss
    • PyTorch version from here
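For reference, the CTC and Focal CTC losses above can be sketched in plain NumPy. This is not the repo's TF Layer implementation: `ctc_loss` and `focal_ctc_loss` are hypothetical helper names, `ctc_loss` scores a single unbatched sequence with the standard CTC forward algorithm (blank index 0 is an assumption), and `focal_ctc_loss` applies the weighting alpha * (1 - exp(-L))^gamma * L commonly used for focal CTC. The default alpha and gamma values are illustrative assumptions, not taken from this repo.

```python
import numpy as np


def ctc_loss(log_probs, labels, blank=0):
    """CTC negative log-likelihood of `labels` for one unbatched sequence.

    log_probs: (T, C) array of per-frame log-softmax scores.
    labels:    target class ids, without blanks.
    """
    T = log_probs.shape[0]
    # Extended target with a blank around every label: [b, l1, b, l2, ..., b]
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    S = len(ext)

    # alpha[t, s]: log-probability of reaching extended position s at frame t
    alpha = np.full((T, S), -np.inf)
    alpha[0, 0] = log_probs[0, ext[0]]
    if S > 1:
        alpha[0, 1] = log_probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            candidates = [alpha[t - 1, s]]  # stay on the same symbol
            if s >= 1:
                candidates.append(alpha[t - 1, s - 1])  # advance by one position
            # skip over a blank only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[t - 1, s - 2])
            alpha[t, s] = np.logaddexp.reduce(candidates) + log_probs[t, ext[s]]

    # valid paths end on the last label or on the trailing blank
    if S == 1:
        log_likelihood = alpha[T - 1, 0]
    else:
        log_likelihood = np.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2])
    return -log_likelihood


def focal_ctc_loss(ctc, alpha=0.25, gamma=0.5):
    """Down-weight easy samples: alpha * (1 - p)^gamma * ctc, with p = exp(-ctc)."""
    ctc = np.asarray(ctc, dtype=np.float64)
    p = np.exp(-ctc)  # probability the model assigns to the whole target sequence
    return alpha * (1.0 - p) ** gamma * ctc
```

Because the recursion runs in log space, long sequences do not underflow. With only two frames, the repeated target [1, 1] is unreachable (CTC needs a blank between repeated labels), so `ctc_loss` returns infinity for it.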

CTC Center Loss comparison (PyTorch vs. TensorFlow)

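import numpy as np
from unit_test import test_ctc_center_loss

n_class = 100  # number of classes
dims = 128     # feature dimension
# random features of shape (batch=32, time steps=16, dims)
x = np.random.normal(size=(32, 16, dims)).astype(np.float32)
# random integer labels in [0, n_class), shape (batch=32, time steps=16)
labels = np.random.randint(0, n_class, (32, 16)).astype(np.int32)
# compare the PyTorch and TF implementations on the same inputs
test_ctc_center_loss(x, labels, n_class, dims)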

Results:

torch loss: 127.31403
tf    loss: 127.314026
loss  diff: 7.6293945e-06

Classification

  • BCE Loss
  • CE Loss
  • Center Loss
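The classification losses are easy to state directly. Below is a minimal NumPy sketch, not the repo's code, of softmax cross-entropy and of center loss (half the squared distance between each feature and its class center, averaged over the batch here), plus one step of the explicit center update rule from the original center-loss formulation. The function names and the update rate `alpha=0.5` are hypothetical choices; whether this repo learns its centers by gradient descent or by an explicit update is an open assumption.

```python
import numpy as np


def softmax_cross_entropy(logits, labels):
    """Mean CE between (N, C) logits and (N,) integer labels."""
    # log-softmax computed stably by subtracting the row max
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])


def center_loss(features, labels, centers):
    """0.5 * mean squared distance of (N, D) features to their (C, D) class centers."""
    diff = features - centers[labels]
    return 0.5 * np.mean(np.sum(diff ** 2, axis=1))


def update_centers(features, labels, centers, alpha=0.5):
    """One explicit center update: c_j -= alpha * sum(c_j - x_i) / (1 + n_j)."""
    new_centers = centers.copy()
    for j in np.unique(labels):
        mask = labels == j
        delta = np.sum(centers[j] - features[mask], axis=0) / (1.0 + mask.sum())
        new_centers[j] = centers[j] - alpha * delta
    return new_centers
```

The `1 + n_j` denominator keeps the step small for classes that appear only a few times in a batch, so their centers do not jump to a single noisy sample.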

Object Detection

  • Smooth L1 Loss
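A minimal NumPy sketch of Smooth L1 loss, assuming elementwise differences averaged over all elements. `smooth_l1_loss` and its `beta` threshold are illustrative names, not this repo's API; `beta=1.0` gives the familiar form 0.5 * x^2 for |x| < 1 and |x| - 0.5 otherwise.

```python
import numpy as np


def smooth_l1_loss(pred, target, beta=1.0):
    """Quadratic for |diff| < beta, linear beyond it; mean over all elements."""
    diff = np.abs(pred - target)
    loss = np.where(diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
    return loss.mean()
```

Both branches equal beta / 2 at |diff| = beta, so the loss is continuous, while the linear branch keeps outlier boxes from dominating the gradient the way a pure L2 loss would.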

Segmentation

  • Dice BCE Loss
  • Dice Loss
  • IoU Loss
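The segmentation losses compare a predicted probability map with a binary mask. Here is a minimal NumPy sketch of soft Dice, soft IoU (Jaccard), and Dice BCE, not this repo's code. The helper names, the `smooth` constant, the BCE clipping `eps`, and the equal 1:1 weighting of Dice and BCE are all assumptions for illustration.

```python
import numpy as np


def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice: 1 - 2 * intersection / (sum(pred) + sum(target))."""
    inter = np.sum(pred * target)
    return 1.0 - (2.0 * inter + smooth) / (np.sum(pred) + np.sum(target) + smooth)


def iou_loss(pred, target, smooth=1e-6):
    """Soft IoU (Jaccard): 1 - intersection / union."""
    inter = np.sum(pred * target)
    union = np.sum(pred) + np.sum(target) - inter
    return 1.0 - (inter + smooth) / (union + smooth)


def bce_loss(pred, target, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    p = np.clip(pred, eps, 1.0 - eps)
    return -np.mean(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))


def dice_bce_loss(pred, target):
    """Dice + BCE with equal weights (the weighting is an assumption)."""
    return dice_loss(pred, target) + bce_loss(pred, target)
```

Dice and IoU look at overlap over the whole mask, which helps when the foreground is small, while the BCE term still gives a per-pixel gradient; summing the two is a common way to get both behaviours.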

Usage

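import tensorflow as tf
from losses import *

# e.g. CTC Loss
loss = CTCLoss()
# x: random float tensor of shape [2, 10, 20] (batch, time steps, classes)
x = tf.random.normal([2, 10, 20])
# y: random integer labels of shape [2, 10] with values in [0, 20)
y = tf.random.uniform([2, 10], maxval=20, dtype=tf.int32)
out = loss(x, y)
print(out)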