kmkurn/pytorch-crf

I have a question about crf

Closed this issue · 7 comments

Does it improve the model's performance?

Is there any difference in this version?

Hi, I'm not sure what your question means. I assume you're asking whether using a CRF layer improves performance compared to a per-token softmax? In some cases it does. I'd try both and compare.
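For concreteness, here is a minimal sketch of such a comparison, assuming this package's `torchcrf.CRF` class; the encoder, tensor shapes, and data below are made-up placeholders rather than a recommended setup:

```python
# Minimal sketch comparing a per-token softmax loss with a CRF loss.
# Assumes pytorch-crf's torchcrf.CRF; `encoder` and the tensors below are
# hypothetical placeholders for your own model and data.
import torch
import torch.nn as nn
from torchcrf import CRF

num_tags = 5
batch_size, seq_len, hidden = 2, 10, 32

encoder = nn.LSTM(hidden, hidden, batch_first=True)   # stand-in feature extractor
emission_layer = nn.Linear(hidden, num_tags)
crf = CRF(num_tags, batch_first=True)

features = torch.randn(batch_size, seq_len, hidden)
tags = torch.randint(num_tags, (batch_size, seq_len))
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

hidden_states, _ = encoder(features)
emissions = emission_layer(hidden_states)              # (batch, seq, num_tags)

# Option 1: per-token softmax, i.e. an independent cross-entropy per position
softmax_loss = nn.functional.cross_entropy(
    emissions.reshape(-1, num_tags), tags.reshape(-1))

# Option 2: CRF layer, i.e. negative log-likelihood of the whole tag sequence
crf_loss = -crf(emissions, tags, mask=mask)

# At test time: per-token argmax vs. Viterbi decoding
softmax_pred = emissions.argmax(dim=-1)
crf_pred = crf.decode(emissions, mask=mask)            # list of tag-index lists
```

The softmax loss treats every position independently, while the CRF scores the whole tag sequence through learned transition scores, which is where any performance difference would come from.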

I used to use a CRF package, but not this version. I found that you deleted the older version, and there is a huge difference between the two versions of the code. I want to know why you changed it.
@kmkurn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


class CRF(nn.Module):
    """
    Class for learning and inference in a conditional random field model using mean field approximation
    and a convolutional approximation in the pairwise potentials term.

    Parameters
    ----------
    n_spatial_dims : int
        Number of spatial dimensions of input tensors.
    filter_size : int or sequence of ints
        Size of the gaussian filters in message passing.
        If it is a sequence, its length must be equal to ``n_spatial_dims``.
    n_iter : int
        Number of iterations in mean field approximation.
    requires_grad : bool
        Whether or not to train the CRF's parameters.
    returns : str
        Can be 'logits', 'proba', 'log-proba'.
    smoothness_weight : float
        Initial weight of the smoothness kernel.
    smoothness_theta : float or sequence of floats
        Initial bandwidths for each spatial feature in the gaussian smoothness kernel.
        If it is a sequence, its length must be equal to ``n_spatial_dims``.
    """

    def __init__(self, n_spatial_dims, filter_size=11, n_iter=5, requires_grad=True,
                 returns='logits', smoothness_weight=1, smoothness_theta=1):
        super().__init__()
        self.n_spatial_dims = n_spatial_dims
        self.n_iter = n_iter
        self.filter_size = np.broadcast_to(filter_size, n_spatial_dims)
        self.returns = returns
        self.requires_grad = requires_grad

        self._set_param('smoothness_weight', smoothness_weight)
        self._set_param('inv_smoothness_theta', 1 / np.broadcast_to(smoothness_theta, n_spatial_dims))

    def _set_param(self, name, init_value):
        setattr(self, name, nn.Parameter(torch.tensor(init_value, dtype=torch.float, requires_grad=self.requires_grad)))

    def forward(self, x, spatial_spacings=None, verbose=False):
        """
        Parameters
        ----------
        x : torch.tensor
            Tensor of shape ``(batch_size, n_classes, *spatial)`` with negative unary potentials, e.g. the CNN's output.
        spatial_spacings : array of floats or None
            Array of shape ``(batch_size, len(spatial))`` with spatial spacings of tensors in batch ``x``.
            None is equivalent to all ones. Used to adapt spatial gaussian filters to different inputs' resolutions.
        verbose : bool
            Whether to display the iterations using a tqdm bar.

        Returns
        -------
        output : torch.tensor
            Tensor of shape ``(batch_size, n_classes, *spatial)``
            with logits or (log-)probabilities of assignment to each class.
        """
        batch_size, n_classes, *spatial = x.shape
        assert len(spatial) == self.n_spatial_dims

        # binary segmentation case
        if n_classes == 1:
            x = torch.cat([x, torch.zeros(x.shape).to(x)], dim=1)

        if spatial_spacings is None:
            spatial_spacings = np.ones((batch_size, self.n_spatial_dims))

        negative_unary = x.clone()

        for i in tqdm(range(self.n_iter), disable=not verbose):
            # normalizing
            x = F.softmax(x, dim=1)

            # message passing
            x = self.smoothness_weight * self._smoothing_filter(x, spatial_spacings)

            # compatibility transform
            x = self._compatibility_transform(x)

            # adding unary potentials
            x = negative_unary - x

        if self.returns == 'logits':
            output = x
        elif self.returns == 'proba':
            output = F.softmax(x, dim=1)
        elif self.returns == 'log-proba':
            output = F.log_softmax(x, dim=1)
        else:
            raise ValueError("Attribute ``returns`` must be 'logits', 'proba' or 'log-proba'.")

        if n_classes == 1:
            output = output[:, 0] - output[:, 1] if self.returns == 'logits' else output[:, 0]
            output.unsqueeze_(1)

        return output

    def _smoothing_filter(self, x, spatial_spacings):
        """
        Parameters
        ----------
        x : torch.tensor
            Tensor of shape ``(batch_size, n_classes, *spatial)`` with negative unary potentials, e.g. logits.
        spatial_spacings : torch.tensor or None
            Tensor of shape ``(batch_size, len(spatial))`` with spatial spacings of tensors in batch ``x``.

        Returns
        -------
        output : torch.tensor
            Tensor of shape ``(batch_size, n_classes, *spatial)``.
        """
        return torch.stack([self._single_smoothing_filter(x[i], spatial_spacings[i]) for i in range(x.shape[0])])

    @staticmethod
    def _pad(x, filter_size):
        padding = []
        for fs in filter_size:
            padding += 2 * [fs // 2]

        return F.pad(x, list(reversed(padding)))  # F.pad pads from the end

    def _single_smoothing_filter(self, x, spatial_spacing):
        """
        Parameters
        ----------
        x : torch.tensor
            Tensor of shape ``(n, *spatial)``.
        spatial_spacing : sequence of len(spatial) floats

        Returns
        -------
        output : torch.tensor
            Tensor of shape ``(n, *spatial)``.
        """
        x = self._pad(x, self.filter_size)
        for i, dim in enumerate(range(1, x.ndim)):
            # reshape to (-1, 1, x.shape[dim])
            x = x.transpose(dim, -1)
            shape_before_flatten = x.shape[:-1]
            x = x.flatten(0, -2).unsqueeze(1)

            # 1d gaussian filtering
            kernel = self._create_gaussian_kernel1d(self.inv_smoothness_theta[i], spatial_spacing[i],
                                                    self.filter_size[i]).view(1, 1, -1).to(x)
            x = F.conv1d(x, kernel)

            # reshape back to (n, *spatial)
            x = x.squeeze(1).view(*shape_before_flatten, x.shape[-1]).transpose(-1, dim)

        return x

    @staticmethod
    def _create_gaussian_kernel1d(inverse_theta, spacing, filter_size):
        """
        Parameters
        ----------
        inverse_theta : torch.tensor
            Tensor of shape ``(,)``
        spacing : float
        filter_size : int

        Returns
        -------
        kernel : torch.tensor
            Tensor of shape ``(filter_size,)``.
        """
        distances = spacing * torch.arange(-(filter_size // 2), filter_size // 2 + 1).to(inverse_theta)
        kernel = torch.exp(-(distances * inverse_theta) ** 2 / 2)
        zero_center = torch.ones(filter_size).to(kernel)
        zero_center[filter_size // 2] = 0
        return kernel * zero_center

    def _compatibility_transform(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor of shape ``(batch_size, n_classes, *spatial)``.

        Returns
        -------
        output : torch.tensor of shape ``(batch_size, n_classes, *spatial)``.
        """
        labels = torch.arange(x.shape[1])
        compatibility_matrix = self._compatibility_function(labels, labels.unsqueeze(1)).to(x)
        return torch.einsum('ij..., jk -> ik...', x, compatibility_matrix)

    @staticmethod
    def _compatibility_function(label1, label2):
        """
        Input tensors must be broadcastable.

        Parameters
        ----------
        label1 : torch.Tensor
        label2 : torch.Tensor

        Returns
        -------
        compatibility : torch.Tensor
        """
        return -(label1 == label2).float()

This is the older version.
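For context, the class above is a spatial mean-field CRF that is applied to a segmentation network's per-pixel logits (its forward pass takes a tensor of shape ``(batch_size, n_classes, *spatial)``). A minimal usage sketch, assuming 2D inputs; the network and shapes here are just placeholders:

```python
# Hypothetical usage of the spatial CRF pasted above (2D segmentation case).
# `segmentation_net` is a placeholder for any network producing per-class logits.
import torch
import torch.nn as nn

segmentation_net = nn.Conv2d(3, 4, kernel_size=1)       # stand-in for a real model
crf = CRF(n_spatial_dims=2, n_iter=5, returns='log-proba')

images = torch.randn(2, 3, 64, 64)                       # (batch, channels, H, W)
logits = segmentation_net(images)                        # (batch, n_classes, H, W)

# The CRF refines the logits with n_iter mean-field updates and returns
# log-probabilities of shape (batch, n_classes, H, W).
log_proba = crf(logits)
targets = torch.randint(4, (2, 64, 64))                  # dummy per-pixel labels
loss = nn.functional.nll_loss(log_proba, targets)
```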

I see. From the code snippet, I don't think it's from this package. Perhaps there's another CRF package that you used in the past? I don't recognise the code as mine.

Thank you! I used to pip install crf, and the crf documentation led me to this URL.
Perhaps the package on PyPI changed its code.