JaxGaussianProcesses/GPJax

feat: Kronecker kernel + computation

daniel-dodd opened this issue

Following the integration of CoLA (#370), it would be great to add a Kronecker kernel together with a corresponding kernel computation.
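For context, here is the structure such a computation could exploit. It assumes a separable product kernel $k(\mathbf{x}, \mathbf{x}') = k_1(x_1, x_1')\,k_2(x_2, x_2')$ evaluated on a Cartesian grid of inputs $\{g^{(1)}_i\}_{i=1}^{n_1} \times \{g^{(2)}_j\}_{j=1}^{n_2}$. Scattered, non-grid inputs do not give this structure. With eigendecompositions $K_d = Q_d \Lambda_d Q_d^\top$ of the per-dimension Gram matrices, the mixed-product property gives:

```math
\begin{aligned}
K &= K_1 \otimes K_2, \qquad [K_1]_{ik} = k_1(g^{(1)}_i, g^{(1)}_k), \quad [K_2]_{jl} = k_2(g^{(2)}_j, g^{(2)}_l),\\
K + \sigma^2 I &= (Q_1 \otimes Q_2)\,(\Lambda_1 \otimes \Lambda_2 + \sigma^2 I)\,(Q_1 \otimes Q_2)^\top,\\
(A \otimes B)\,\mathrm{vec}(X) &= \mathrm{vec}(B X A^\top) \quad \text{(column-major } \mathrm{vec}\text{)}.
\end{aligned}
```

As a result, the Kronecker factors never need to be expanded into the full $n_1 n_2 \times n_1 n_2$ matrix.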

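```python
from dataclasses import dataclass
from typing import Sequence

from gpjax.base import static_field
from gpjax.kernels import RBF
from gpjax.kernels.base import AbstractKernel, CombinationKernel
from gpjax.kernels.computations import AbstractKernelComputation


# Could inherit from either CombinationKernel or ProductKernel.
@dataclass
class Kronecker(CombinationKernel):
    kernels: Sequence[AbstractKernel] = None
    # KroneckerKernelComputation does not exist yet: it is the computation proposed here.
    compute_engine: AbstractKernelComputation = static_field(KroneckerKernelComputation())
    ...

    def __post_init__(self):
        # Check that the kernel's influence is separable across each dimension
        # or set of dimensions, i.e. that the sub-kernels' active_dims do not overlap.
        ...


# Demo: one RBF per input dimension, combined into a Kronecker kernel.
k1 = RBF(active_dims=[0])
k2 = RBF(active_dims=[1])
kron_kern = Kronecker(kernels=[k1, k2])
```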
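To illustrate what a Kronecker computation would buy, here is a minimal sketch in plain JAX. It deliberately avoids GPJax because `KroneckerKernelComputation` does not exist yet. It assumes grid-structured inputs and a product of two one-dimensional RBF kernels, mirroring the demo's `active_dims=[0]` and `active_dims=[1]`. The helpers `rbf`, `kron_matvec` and `kron_solve` are hypothetical names chosen for this illustration. The sketch checks that the dense Gram matrix equals `jnp.kron(K1, K2)` and that a noisy solve via per-factor eigendecompositions matches the dense solve.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so the dense vs. Kronecker comparison is tight


def rbf(a, b, lengthscale=1.0, variance=1.0):
    """Squared-exponential Gram matrix between 1-D input vectors a and b (hypothetical helper)."""
    sq_dist = (a[:, None] - b[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)


def kron_matvec(A, B, x):
    """Compute (A ⊗ B) @ x without forming A ⊗ B, using row-major flattening."""
    X = x.reshape(A.shape[1], B.shape[1])
    return (A @ X @ B.T).reshape(-1)


def kron_solve(K1, K2, y, noise):
    """Solve (K1 ⊗ K2 + noise * I) alpha = y using eigendecompositions of K1 and K2 only."""
    lam1, Q1 = jnp.linalg.eigh(K1)
    lam2, Q2 = jnp.linalg.eigh(K2)
    # Rotate into the joint eigenbasis: (Q1 ⊗ Q2)^T y = (Q1^T ⊗ Q2^T) y.
    z = kron_matvec(Q1.T, Q2.T, y)
    # Eigenvalues of K1 ⊗ K2 are lam1[i] * lam2[j], in the same row-major order.
    z = z / (jnp.kron(lam1, lam2) + noise)
    # Rotate back: (Q1 ⊗ Q2) z.
    return kron_matvec(Q1, Q2, z)


# Grid inputs: every pair (g1[i], g2[j]) is a data point, ordered row-major over (i, j).
g1 = jnp.linspace(0.0, 1.0, 5)
g2 = jnp.linspace(0.0, 2.0, 4)
X1, X2 = jnp.meshgrid(g1, g2, indexing="ij")
X = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)  # shape (20, 2)

# Per-dimension Gram matrices, one per active dimension.
K1 = rbf(g1, g1, lengthscale=0.3)
K2 = rbf(g2, g2, lengthscale=0.7)

# Dense Gram of the product kernel k((a, b), (c, d)) = k1(a, c) * k2(b, d).
K_dense = rbf(X[:, 0], X[:, 0], lengthscale=0.3) * rbf(X[:, 1], X[:, 1], lengthscale=0.7)

y = jnp.sin(3.0 * X[:, 0]) + jnp.cos(X[:, 1])
noise = 0.1

alpha_kron = kron_solve(K1, K2, y, noise)
alpha_dense = jnp.linalg.solve(K_dense + noise * jnp.eye(X.shape[0]), y)

print(jnp.allclose(K_dense, jnp.kron(K1, K2)))
print(jnp.allclose(alpha_kron, alpha_dense))
```

The dense solve costs $\mathcal{O}((n_1 n_2)^3)$. The factored solve costs $\mathcal{O}(n_1^3 + n_2^3)$ for the two eigendecompositions plus $\mathcal{O}(n_1 n_2 (n_1 + n_2))$ for the rotations. This is the kind of saving a `compute_engine` could deliver, for example by handing CoLA a Kronecker-structured operator instead of a dense matrix.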