chalk-diagrams/chalk

Option to place arrows below w.r.t `atop` in diagram


The connect_ family of functions does diagram + arrow as opposed to arrow + diagram; in each of them the return statement is:

    return dia + arrow_between(ps, pe, style)

Since atop (or +) is not commutative, the two orders render differently, and I've found a few use cases for the latter. I think both are useful, so this is a feature request for choosing between arrow + diagram and diagram + arrow through a switch.
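To make the ordering argument concrete, here is a minimal toy model in plain Python; it is not chalk's API. It assumes, as the thread reports, that `a + b` paints `b` over `a`. `Prim`, `Compose`, `draw_order`, `toy_connect` and its `below` flag are hypothetical names for illustration; real chalk connect functions take subdiagram names and build the arrow themselves, whereas here the arrow is passed in directly to keep the sketch small.

```python
from dataclasses import dataclass
from typing import List


class Diagram:
    def __add__(self, other: "Diagram") -> "Diagram":
        # Toy assumption matching the thread: `a + b` paints b over a.
        return Compose(self, other)


@dataclass
class Prim(Diagram):
    name: str


@dataclass
class Compose(Diagram):
    diagram1: Diagram
    diagram2: Diagram


def draw_order(d: Diagram) -> List[str]:
    # Painter's algorithm: earlier names are painted first (further back).
    if isinstance(d, Compose):
        return draw_order(d.diagram1) + draw_order(d.diagram2)
    return [d.name]


def toy_connect(dia: Diagram, arrow: Diagram, below: bool = False) -> Diagram:
    # Hypothetical switch: the default keeps today's `dia + arrow` order.
    return arrow + dia if below else dia + arrow


stack = Prim("linear_0") + Prim("linear_1")
print(draw_order(toy_connect(stack, Prim("arrow"))))              # ['linear_0', 'linear_1', 'arrow']
print(draw_order(toy_connect(stack, Prim("arrow"), below=True)))  # ['arrow', 'linear_0', 'linear_1']
```

Defaulting `below` to `False` keeps the existing `dia + arrow` behaviour, which is what would make such a switch backwards compatible.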

To motivate the use case with an example: below, on the left, is a reproduction of MultiHeadAttention, next to the figure from the original paper on the right.

One way to specify this diagram declaratively is to stack a ScaledDotProductAttention (SDPA) block for each head and then have Split branch out to each head. connect_outside works for me here, but since I connect after the stack is created, the arrows appear on top instead of honoring the depth dimension.

I have had some success with a custom Trail-based arrow; see arrow_outside_up_free and arrow_outside_up.

PS: I don't mean to pile on issues here; I'm happy to help and open a PR myself, with some guidance, once we reach consensus. Also, if there are recommended alternatives using existing primitives that solve the above problem, I'm open to trying those as well. Thanks for building and maintaining the library!

Hello Jerin! I got some time to read your issue. For this particular example, I think the easiest way to achieve the correct layering is to draw the arrows as soon as you have drawn the other head components. That is, move this part

            dia = arrow_outside_up(
                dia,
                f"linear_{x}_" + str(i),
                "sdpa_" + str(i),
                left=True,
            )

in the head function:

        dia = vcat(center([sdpa, hcat([q, k, v], hspace / 2)]), vspace)
        for x in "vkq":
            dia = arrow_outside_up(
                dia,
                f"linear_{x}_{i}",
                f"sdpa_{i}",
                left=True,
            )
        dia = dia.center_xy().translate(dx, dy).fill_opacity(opacity)
        return dia
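Why moving the connect into head helps can be seen with a toy painter's-order model (plain Python lists, not chalk). Each diagram is represented only by the order in which its layers are painted, and the hypothetical `over(a, b)` stands in for `a + b`, painting `b` on top. The layer names and the assumption that later heads are painted over earlier ones are illustrative only.

```python
from functools import reduce
from typing import List


def over(a: List[str], b: List[str]) -> List[str]:
    # Stand-in for `a + b`: everything in b is painted on top of a.
    return a + b


def blocks_for(i: int) -> List[str]:
    return [f"sdpa_{i}", f"linear_q_{i}", f"linear_k_{i}", f"linear_v_{i}"]


def arrows_for(i: int) -> List[str]:
    return [f"arrow_{x}_{i}" for x in "vkq"]


def head(i: int, with_arrows: bool) -> List[str]:
    # Suggested fix: connect inside the head, before heads are stacked.
    return over(blocks_for(i), arrows_for(i)) if with_arrows else blocks_for(i)


def paints_before(order: List[str], a: str, b: str) -> bool:
    # True if layer a ends up behind layer b.
    return order.index(a) < order.index(b)


n_heads = 3
# Original approach: stack bare heads, then add every arrow at the end.
late = over(
    reduce(over, [head(i, False) for i in range(n_heads)]),
    [a for i in range(n_heads) for a in arrows_for(i)],
)
# Suggested approach: each head already carries its own arrows.
early = reduce(over, [head(i, True) for i in range(n_heads)])

print(paints_before(late, "arrow_v_0", "linear_v_2"))   # False: head 0's arrow covers head 2
print(paints_before(early, "arrow_v_0", "linear_v_2"))  # True: head 0's arrow stays behind head 2
```

Building the arrows per head interleaves them with the stack, which no single global "arrows on top" or "arrows below" choice can reproduce.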

This change does seem to yield the desired output:
[Screenshot 2023-09-11 at 15 52 10]

Without this change, even with the left flag, it seems to me that we cannot reach a satisfactory result: the arrow from linear_v_2 to sdpa_2 is behind the linear_v_0 and linear_v_1 blocks (even though this is not immediately obvious due to the transparency of the linear blocks). So maybe a more general solution is needed. I recall that @srush also suggested a similar idea as a possible improvement, but I don't know whether he had a specific approach in mind:

Render arrows in between two already plotted values in Z-space (Not sure if this is possible in a functional system)

Do we know what Haskell diagrams does here? I can try to figure it out.

Do we know what Haskell diagrams does here?

I don't recall seeing anything too similar in the Haskell codebase 🤔. Maybe the idea of delayed composition is related, as it allows reordering the components of a diagram after their creation, but I don't feel that this is much easier to use than simply creating the diagram in the "right" order from the get-go. Otherwise, for arrow connections, the connect-like functions also seem to use a predefined order (via atop).
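One way to render arrows in between two already plotted values in Z-space, without reordering the composition, is an explicit depth key that the renderer sorts on. This is a swapped-in technique (a z-index plus a stable sort), not Haskell diagrams' delayed composition and not an existing chalk feature; `Layer`, `z` and `render_order` are hypothetical.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Layer:
    name: str
    z: float = 0.0  # hypothetical depth key; higher values are painted later


def render_order(layers: List[Layer]) -> List[str]:
    # sorted() is stable, so layers with equal z keep their composition
    # order; with every z left at 0.0 this reproduces plain `+` ordering.
    return [layer.name for layer in sorted(layers, key=lambda layer: layer.z)]


# The arrow is composed last, but its z puts it between the two linears.
composed = [Layer("linear_0", z=0), Layer("linear_1", z=1), Layer("arrow_0", z=0.5)]
print(render_order(composed))  # ['linear_0', 'arrow_0', 'linear_1']
```

The cost is that depth becomes a per-primitive attribute the renderer must honour globally, rather than something implied purely by how diagrams are composed.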

Looks like my wall of text created some confusion and more errors 😓. Please allow me to clarify.

  1. A flag that lets the user change the order of atop can be a backwards-compatible change, with the default being diagram + arrow and an option that switches to arrow + diagram. As for arrow_outside_up and arrow_outside_free, I was pointing to their left flag, which swaps the order for atop. I argue that this option on the connect family of functions would be a net improvement without breaking any existing code. (If there's an objective to stay closely compatible with Haskell diagrams, this might be a problem.)

  2. The rest is me arguing for the use case of such a flag. To clarify, I'm trying to use the switch to create Split -> Linear (branch) connect arrows. Here I want to build the z-stack of Linears first, vcat it properly with Split, and connect afterwards. connect_outside or connect would work for me here, except that in the current state (without the switch) the arrows come out on top, because the connects happen after the z-stack (Linear) and Split are declared. (I see this wasn't clear in writing, apologies 😓.)


I'm not sure if I follow your response:

That is, move this part ... in the head function

Both permalinks are the same. I assume you meant moving this function call into head. Comparing the former code with your suggestion, there's an opacity difference that shows the arrows' z-index (done properly in your version), which is clearly the better way to render it. Thanks for pointing that out. Let me know if I'm missing something.

Thanks for the clarification, Jerin! I understand your motivation for the flag, but while it solves the "split" to "linear" use case, it doesn't seem to solve the "linear" to "scaled dot product attention" example.

Namely, we cannot achieve something like this:
[Screenshot 2023-09-17 at 00 16 35]
with either the original connect, which would yield
[Screenshot 2023-09-17 at 00 17 16]
or the reversed-order version, which would yield
[Screenshot 2023-09-17 at 00 18 34]

Ideally, I would prefer a solution that also addresses this sort of case. Sasha, do you have any opinion on this matter?

Otherwise, here is a hack that reverses the order of elements after a connect:

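    def connect_outside_reverse(*args, **kwargs):
        # connect_outside returns `dia + arrow`; swapping the two parts
        # of that composition draws the arrow underneath the diagram.
        dia = connect_outside(*args, **kwargs)
        return dia.diagram2 + dia.diagram1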

or, written as a combinator:

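    from chalk.core import Compose

    def zswap(diagram):
        # If the diagram is a top-level composition (e.g. the `dia + arrow`
        # returned by a connect function), swap its two halves so that the
        # second one (the arrow) is drawn underneath; otherwise leave it as is.
        if isinstance(diagram, Compose):
            return diagram.diagram2 + diagram.diagram1
        else:
            return diagram

    zswap(connect_outside(dia, f"bot {i}", f"top {i}"))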
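The hack relies on the outermost node of a connect result being the `dia + arrow` composition. A toy replica in plain Python (not chalk's classes; `Prim`, `Compose` and `draw_order` are illustrative stand-ins) shows both the intended effect and a caveat: `zswap` swaps whatever composition it is given, so applying it to anything other than a fresh connect result reorders unrelated content.

```python
from dataclasses import dataclass
from typing import List


class Diagram:
    def __add__(self, other: "Diagram") -> "Diagram":
        # As reported in the thread, `a + b` paints b over a.
        return Compose(self, other)


@dataclass
class Prim(Diagram):
    name: str


@dataclass
class Compose(Diagram):
    diagram1: Diagram
    diagram2: Diagram


def draw_order(d: Diagram) -> List[str]:
    # Painter's algorithm: diagram1 first, then diagram2 on top.
    if isinstance(d, Compose):
        return draw_order(d.diagram1) + draw_order(d.diagram2)
    return [d.name]


def zswap(diagram: Diagram) -> Diagram:
    # Same logic as zswap above, on the toy classes: only the outermost
    # composition is swapped; nested ones are left untouched.
    if isinstance(diagram, Compose):
        return diagram.diagram2 + diagram.diagram1
    return diagram


stack = Prim("linear_0") + Prim("linear_1")
connected = stack + Prim("arrow")  # stand-in for a connect_outside(...) result
print(draw_order(zswap(connected)))  # ['arrow', 'linear_0', 'linear_1']
print(draw_order(zswap(stack)))      # ['linear_1', 'linear_0']: unrelated reorder
```

So the combinator is safe only when applied directly to the value a connect call returns.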

Both permalinks are the same. I assume you meant moving this function call into head.

Oops, yes, you are right! Here is the corresponding code that generated that figure.