使用BrainCog模拟RMSNorm,最终输出为nan
Opened this issue · 0 comments
LumenScope commented
class SNN_RMSNorm(nn.Module):
def __init__(self, max_length = 128, hidden_size=4096,node=LIAFNode, threshold=0.5, eps=1e-6):
super().__init__()
self.eps = eps
self.rms_neuron = node(act_fun='LeakyReLU', threshold=threshold)
self.weight_neuron = node(act_fun='ReLU', threshold=threshold)
self.weight = nn.Parameter(torch.ones(hidden_size,hidden_size))
self.rms_connection = CustomLinear(torch.ones(1,hidden_size))
self.weight_connection = CustomLinear(self.weight)
def forward(self, x):
x_sqr = x ** 2
x_rms = x_sqr.mean(-1, keepdim=True)
s_rms = self.rms_neuron(self.rms_connection(x_rms))
rms_out = torch.rsqrt(s_rms + self.eps)
s_scale = self.weight_neuron(self.weight_connection(rms_out))
return s_scale
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
以上为我定义的SNN化RMSNorm和原始RMSNorm函数,以下为SNN_RMSNorm前向传播的输出,虽然维度shape经过我的处理达到了一致,但是输出如下:
tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
grad_fn=<StackBackward0>)
torch.Size([2, 128, 4096])
以下为全部代码:
from torchvision import transforms
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from spikingjelly.clock_driven.neuron import MultiStepLIFNode, MultiStepParametricLIFNode
from transformers import CLIPProcessor, CLIPModel
from accelerate import Accelerator
from dataclasses import dataclass
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
import numpy as np
import os
import sys
from torch.nn import Parameter
import abc
from abc import ABC
from einops import rearrange, repeat
accelerator = Accelerator()
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
max_batch_size: int = 32
max_seq_len: int = 2048
class CustomLinear(nn.Module):
"""
用户自定义连接 通常stdp的计算
"""
def __init__(self, weight, mask=None):
super().__init__()
self.weight = nn.Parameter(weight, requires_grad=True)
self.mask = mask
def forward(self, x: torch.Tensor):
"""
:param x:输入 x.shape = [N ]
"""
#
# ret.shape = [C]
return x.matmul(self.weight)
def update(self, dw):
"""
:param dw:权重更新量
"""
with torch.no_grad():
if self.mask is not None:
dw *= self.mask
self.weight.data += dw
class STDP(nn.Module):
"""
STDP learning rule
"""
def __init__(self, node, connection, decay=0.99):
"""
:param node:node神经元类型实例如IFNode LIFNode
:param connection:连接 类的实例 里面只能有一个操作
"""
super().__init__()
self.node = node
self.connection = connection
self.trace = None
self.decay = decay
def forward(self, x):
"""
计算前向传播过程
:return:s是脉冲 dw更新量
"""
x = x.clone().detach()
i = self.connection(x)
with torch.no_grad():
s = self.node(i)
i.data += s - i.data
trace = self.cal_trace(x)
x.data += trace - x.data
dw = torch.autograd.grad(
outputs=i, inputs=self.connection.weight, grad_outputs=i)
return s, dw
def cal_trace(self, x):
"""
计算trace
"""
if self.trace is None:
self.trace = Parameter(x.clone().detach(), requires_grad=False)
else:
self.trace *= self.decay
self.trace += x
return self.trace.detach()
def reset(self):
"""
重置
"""
self.trace = None
def heaviside(x):
return (x >= 0.).to(x.dtype)
class quadratic_gate(torch.autograd.Function):
"""
使用 quadratic_gate 作为代理梯度函数
对应的原函数为:
.. math::
g(x) =
\\begin{cases}
0, & x < -\\frac{1}{\\alpha} \\\\
-\\frac{1}{2}\\alpha^2|x|x + \\alpha x + \\frac{1}{2}, & |x| \\leq \\frac{1}{\\alpha} \\\\
1, & x > \\frac{1}{\\alpha} \\\\
\\end{cases}
反向传播的函数为:
.. math::
g'(x) =
\\begin{cases}
0, & |x| > \\frac{1}{\\alpha} \\\\
-\\alpha^2|x|+\\alpha, & |x| \\leq \\frac{1}{\\alpha}
\\end{cases}
"""
@staticmethod
def forward(ctx, x, alpha):
if x.requires_grad:
mask_zero = (x.abs() > 1 / alpha)
grad_x = -alpha * alpha * x.abs() + alpha
grad_x.masked_fill_(mask_zero, 0)
ctx.save_for_backward(grad_x)
return x.gt(0.).float()
@staticmethod
def backward(ctx, grad_output):
grad_x = None
if ctx.needs_input_grad[0]:
grad_x = grad_output * ctx.saved_tensors[0]
return grad_x, None
class SurrogateFunctionBase(nn.Module):
"""
Surrogate Function 的基类
:param alpha: 为一些能够调控函数形状的代理函数提供参数.
:param requires_grad: 参数 ``alpha`` 是否需要计算梯度, 默认为 ``False``
"""
def __init__(self, alpha, requires_grad=True):
super().__init__()
self.alpha = nn.Parameter(
torch.tensor(alpha, dtype=torch.float),
requires_grad=requires_grad)
@staticmethod
def act_fun(x, alpha):
"""
:param x: 膜电位的输入
:param alpha: 控制代理梯度形状的变量, 可以为 ``NoneType``
:return: 激发之后的spike, 取值为 ``[0, 1]``
"""
raise NotImplementedError
def forward(self, x):
"""
:param x: 膜电位输入
:return: 激发之后的spike
"""
return self.act_fun(x, self.alpha)
'''
sigmoid surrogate function.
'''
class QGateGrad(SurrogateFunctionBase):
def __init__(self, alpha=2., requires_grad=False):
super().__init__(alpha, requires_grad)
@staticmethod
def act_fun(x, alpha):
return quadratic_gate.apply(x, alpha)
class relu_like(torch.autograd.Function):
@staticmethod
def forward(ctx, x, alpha):
if x.requires_grad:
ctx.save_for_backward(x, alpha)
return heaviside(x)
@staticmethod
def backward(ctx, grad_output):
grad_x, grad_alpha = None, None
x, alpha = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_x = grad_output * x.gt(0.).float() * alpha
if ctx.needs_input_grad[1]:
grad_alpha = (grad_output * F.relu(x)).sum()
return grad_x, grad_alpha
class RoundGrad(nn.Module):
def __init__(self, **kwargs):
super(RoundGrad, self).__init__()
self.act = nn.Hardtanh(-.5, 4.5)
def forward(self, x):
x = self.act(x)
return x.ceil() + x - x.detach()
class backeigate(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return input.gt(0.).float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
temp = abs(input) < 0.5
return grad_input * temp.float()
class BackEIGateGrad(SurrogateFunctionBase):
def __init__(self, alpha=2., requires_grad=False):
super().__init__(alpha, requires_grad)
@staticmethod
def act_fun(x, alpha):
return backeigate.apply(x)
class ei(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return torch.sign(input).float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
temp = abs(input) < 0.5
return grad_input * temp.float()
class BaseNode(nn.Module, abc.ABC):
"""
神经元模型的基类
:param threshold: 神经元发放脉冲需要达到的阈值
:param v_reset: 静息电位
:param dt: 时间步长
:param step: 仿真步
:param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``
:param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``
:param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``
:param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``
:param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组
:param mem_detach: 是否将上一时刻的膜电位在计算图中截断
:param args: 其他的参数
:param kwargs: 其他的参数
"""
def __init__(self,
threshold=.5,
v_reset=0.,
dt=1.,
step=8,
requires_thres_grad=False,
sigmoid_thres=False,
requires_fp=False,
layer_by_layer=False,
n_groups=1,
*args,
**kwargs):
super(BaseNode, self).__init__()
self.threshold = Parameter(torch.tensor(
threshold), requires_grad=requires_thres_grad)
self.sigmoid_thres = sigmoid_thres
self.mem = 0.
self.spike = 0.
self.dt = dt
self.feature_map = []
self.mem_collect = []
self.requires_fp = requires_fp
self.v_reset = v_reset
self.step = step
self.layer_by_layer = layer_by_layer
self.groups = n_groups
self.mem_detach = kwargs['mem_detach'] if 'mem_detach' in kwargs else False
self.requires_mem = kwargs['requires_mem'] if 'requires_mem' in kwargs else False
@abc.abstractmethod
def calc_spike(self):
"""
通过当前的mem计算是否发放脉冲,并reset
:return: None
"""
pass
def integral(self, inputs):
"""
计算由当前inputs对于膜电势的累积
:param inputs: 当前突触输入电流
:type inputs: torch.tensor
:return: None
"""
pass
def get_thres(self):
return self.threshold if not self.sigmoid_thres else self.threshold.sigmoid()
def rearrange2node(self, inputs):
if self.groups != 1:
if len(inputs.shape) == 4:
outputs = rearrange(
inputs, 'b (c t) w h -> t b c w h', t=self.step)
elif len(inputs.shape) == 2:
outputs = rearrange(inputs, 'b (c t) -> t b c', t=self.step)
else:
raise NotImplementedError
elif self.layer_by_layer:
if len(inputs.shape) == 4:
outputs = rearrange(
inputs, '(t b) c w h -> t b c w h', t=self.step)
elif len(inputs.shape) == 3:
outputs = rearrange(
inputs, '(t b) n c -> t b n c', t=self.step)
elif len(inputs.shape) == 2:
outputs = rearrange(inputs, '(t b) c -> t b c', t=self.step)
else:
raise NotImplementedError
else:
outputs = inputs
return outputs
def rearrange2op(self, inputs):
if self.groups != 1:
if len(inputs.shape) == 5:
outputs = rearrange(inputs, 't b c w h -> b (c t) w h')
elif len(inputs.shape) == 3:
outputs = rearrange(inputs, ' t b c -> b (c t)')
else:
raise NotImplementedError
elif self.layer_by_layer:
if len(inputs.shape) == 5:
outputs = rearrange(inputs, 't b c w h -> (t b) c w h')
elif len(inputs.shape) == 4:
outputs = rearrange(inputs, ' t b n c -> (t b) n c')
elif len(inputs.shape) == 3:
outputs = rearrange(inputs, ' t b c -> (t b) c')
else:
raise NotImplementedError
else:
outputs = inputs
return outputs
def forward(self, inputs):
"""
torch.nn.Module 默认调用的函数,用于计算膜电位的输入和脉冲的输出
在```self.requires_fp is True``` 的情况下,可以使得```self.feature_map```用于记录trace
:param inputs: 当前输入的膜电位
:return: 输出的脉冲
"""
if self.layer_by_layer or self.groups != 1:
inputs = self.rearrange2node(inputs)
outputs = []
for i in range(self.step):
if self.mem_detach and hasattr(self.mem, 'detach'):
self.mem = self.mem.detach()
self.spike = self.spike.detach()
self.integral(inputs[i])
self.calc_spike()
if self.requires_fp is True:
self.feature_map.append(self.spike)
if self.requires_mem is True:
self.mem_collect.append(self.mem)
outputs.append(self.spike)
outputs = torch.stack(outputs)
outputs = self.rearrange2op(outputs)
return outputs
else:
if self.mem_detach and hasattr(self.mem, 'detach'):
self.mem = self.mem.detach()
self.spike = self.spike.detach()
self.integral(inputs)
self.calc_spike()
if self.requires_fp is True:
self.feature_map.append(self.spike)
if self.requires_mem is True:
self.mem_collect.append(self.mem)
return self.spike
def n_reset(self):
"""
神经元重置,用于模型接受两个不相关输入之间,重置神经元所有的状态
:return: None
"""
self.mem = self.v_reset
self.spike = 0.
self.feature_map = []
self.mem_collect = []
def get_n_attr(self, attr):
if hasattr(self, attr):
return getattr(self, attr)
else:
return None
def set_n_warm_up(self, flag):
"""
一些训练策略会在初始的一些epoch,将神经元视作ANN的激活函数训练,此为设置是否使用该方法训练
:param flag: True:神经元变为激活函数, False:不变
:return: None
"""
self.warm_up = flag
def set_n_threshold(self, thresh):
"""
动态设置神经元的阈值
:param thresh: 阈值
:return:
"""
self.threshold = Parameter(torch.tensor(
thresh, dtype=torch.float), requires_grad=False)
def set_n_tau(self, tau):
"""
动态设置神经元的衰减系数,用于带Leaky的神经元
:param tau: 衰减系数
:return:
"""
if hasattr(self, 'tau'):
self.tau = Parameter(torch.tensor(
tau, dtype=torch.float), requires_grad=False)
else:
raise NotImplementedError
class LIFNode(BaseNode):
"""
Leaky Integrate and Fire
:param threshold: 神经元发放脉冲需要达到的阈值
:param v_reset: 静息电位
:param dt: 时间步长
:param step: 仿真步
:param tau: 膜电位时间常数, 用于控制膜电位衰减
:param act_fun: 使用surrogate gradient 对梯度进行近似, 默认为 ``surrogate.AtanGrad``
:param requires_thres_grad: 是否需要计算对于threshold的梯度, 默认为 ``False``
:param sigmoid_thres: 是否使用sigmoid约束threshold的范围搭到 [0, 1], 默认为 ``False``
:param requires_fp: 是否需要在推理过程中保存feature map, 需要消耗额外的内存和时间, 默认为 ``False``
:param layer_by_layer: 是否以一次性计算所有step的输出, 在网络模型较大的情况下, 一般会缩短单次推理的时间, 默认为 ``False``
:param n_groups: 在不同的时间步, 是否使用不同的权重, 默认为 ``1``, 即不分组
:param args: 其他的参数
:param kwargs: 其他的参数
"""
def __init__(self, threshold=0.5, tau=2., act_fun=QGateGrad, *args, **kwargs):
super().__init__(threshold, *args, **kwargs)
self.tau = tau
if isinstance(act_fun, str):
act_fun = eval(act_fun)
self.act_fun = act_fun(alpha=2., requires_grad=False)
# self.threshold = threshold
# print(threshold)
# print(tau)
def integral(self, inputs):
self.mem = self.mem + (inputs - self.mem) / self.tau
def calc_spike(self):
self.spike = self.act_fun(self.mem - self.threshold)
self.mem = self.mem * (1 - self.spike.detach())
class LIAFNode(BaseNode):
"""
Leaky Integrate and Analog Fire (LIAF), Reference: https://ieeexplore.ieee.org/abstract/document/9429228
与LIF相同, 但前传的是膜电势, 更新沿用阈值和膜电势
:param act_fun: 前传使用的激活函数 [ReLU, SeLU, LeakyReLU]
:param threshold_related: 阈值依赖模式,若为"True"则 self.spike = act_fun(mem-threshold)
:note that BaseNode return self.spike, and here self.spike is analog value.
"""
def __init__(self, spike_act=BackEIGateGrad(), act_fun="SELU", threshold=0.5, tau=2., threshold_related=True, *args, **kwargs):
super().__init__(threshold, *args, **kwargs)
if isinstance(act_fun, str):
act_fun = eval("nn." + act_fun + "()")
self.tau = tau
self.act_fun = act_fun
self.spike_act = spike_act
self.threshold_related = threshold_related
def integral(self, inputs):
self.mem = self.mem + (inputs - self.mem) / self.tau
def calc_spike(self):
if self.threshold_related:
spike_tmp = self.act_fun(self.mem - self.threshold)
else:
spike_tmp = self.act_fun(self.mem)
self.spike = self.spike_act(self.mem - self.threshold)
self.mem = self.mem * (1 - self.spike)
self.spike = spike_tmp
class SNN_RMSNorm(nn.Module):
def __init__(self, max_length = 128, hidden_size=4096,node=LIAFNode, threshold=0.5, eps=1e-6):
super().__init__()
self.eps = eps
self.rms_neuron = node(act_fun='LeakyReLU', threshold=threshold)
self.weight_neuron = node(act_fun='ReLU', threshold=threshold)
self.weight = nn.Parameter(torch.ones(hidden_size,hidden_size))
self.rms_connection = CustomLinear(torch.ones(1,hidden_size))
self.weight_connection = CustomLinear(self.weight)
def forward(self, x):
x_sqr = x ** 2
x_rms = x_sqr.mean(-1, keepdim=True)
s_rms = self.rms_neuron(self.rms_connection(x_rms))
rms_out = torch.rsqrt(s_rms + self.eps)
s_scale = self.weight_neuron(self.weight_connection(rms_out))
return s_scale