Option to place arrows below w.r.t `atop` in diagram
The `connect_` family of functions do `diagram + arrow` as opposed to `arrow + diagram`.
Since `atop` (or `+`) is not commutative, these have different output renderings, and I've found a few use cases for the latter. I think both are useful, and this is a feature request for specifying `arrow + diagram` or `diagram + arrow` through a switch.
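To make the ordering point concrete, here is a minimal sketch in plain Python. It does not use chalk: the `atop` below is a hypothetical stand-in that models a diagram as a list of layer names painted back to front, on the assumption (taken from the behaviour described above) that the second operand ends up on top of the first. The layer names are made up for illustration.

```python
from typing import List

# Toy diagram: layer names in painting order, index 0 is at the very back.
Layers = List[str]


def atop(first: Layers, second: Layers) -> Layers:
    """Toy composition: `second` is painted after (above) `first`."""
    return first + second


diagram = ["linear_0", "linear_1", "split"]
arrow = ["arrow"]

arrows_above = atop(diagram, arrow)  # models diagram + arrow
arrows_below = atop(arrow, diagram)  # models arrow + diagram

print(arrows_above)
print(arrows_below)
```

The two results contain the same layers but in a different painting order, which is all the requested switch would need to choose between.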
To motivate the use case with an example, left below is a reproduction of `MultiHeadAttention` from the original paper (right).
One process for declaratively specifying this diagram involves stacking `ScaledDotProductAttention` (SDPA) for each head and then having `Split` branch out to each head. `connect_outside` works for me here, but since I'm doing the connect after the stack is created, the arrows appear above, not honoring the depth dimension.
I have found some success with a custom `Trail`-based arrow; see `arrow_outside_up_free` and `arrow_outside_up`.
PS: I don't mean to pile on the issues here; I'm happy to help and bring in a PR myself following consensus, with some guidance. Also, if there are alternative recommended routes with existing primitives that solve the above problem, I'm open to trying those as well. Thanks for building and maintaining the library!
Hello Jerin! I got some time to read your issue. For this particular example, I think the easiest way to achieve the correct layering is by drawing the arrows once you draw the other `head` components. That is, move this part
dia = arrow_outside_up(
    dia,
    f"linear_{x}_" + str(i),
    "sdpa_" + str(i),
    left=True,
)
into the `head` function:
dia = vcat(center([sdpa, hcat([q, k, v], hspace / 2)]), vspace)
for x in "vkq":
    dia = arrow_outside_up(
        dia,
        f"linear_{x}_{i}",
        f"sdpa_{i}",
        left=True,
    )
dia = dia.center_xy().translate(dx, dy).fill_opacity(opacity)
return dia
This change does seem to yield the desired output:
Otherwise, even with the `left` flag it seems to me that we cannot reach a satisfactory result: the arrow from `linear_v_2` to `sdpa_2` is behind the `linear_v_0` and `linear_v_1` blocks (even though this is not immediately obvious due to the transparency of the `linear` blocks). So maybe a more general solution is needed? I recall that @srush also suggested a similar idea as a possible improvement, but I don't know if he had any approach in mind:
Render arrows in between two already plotted values in Z-space (Not sure if this is possible in a functional system)
Do we know what Haskell diagrams does here? I can try to figure it out.
Do we know what Haskell diagrams does here?
I don't recall seeing anything too similar in the Haskell codebase. Maybe the idea of delayed composition is related, as it allows reordering the components of a diagram after their creation, but I don't feel that this is much easier to use than simply creating the diagram in the "right" order from the get-go. Otherwise, for arrow connections, the `connect`-like functions also seem to use a predefined order (via `atop`).
Looks like my large text created some confusion and more errors. Please allow me to clarify.
- A flag that gives the user the ability to change the order of `atop` can be a backwards-compatible change, with the default being `diagram + arrow` and an option that switches to `arrow + diagram`. As for `arrow_outside_up` and `arrow_outside_free`, I was pointing at the `left` flag swapping the order for `atop`. I argue this option on the `connect` family of functions should be a net improvement without breaking anything existing. (If there's some objective to keep close compatibility with haskell-diagrams, then this might be a problem.)
- The remainder is me arguing for the use case of such a flag. To clarify, I'm trying to use a switch to create `Split` -> `Linear` (branch) connect arrows. In this case, I want to plot the z-stack of `Linear`s first, then `vcat` them properly with `Split`, and connect afterwards. `connect_outside` or `connect` should work for me here, except that in the current state (without the switch) the arrows come on top when the connects happen after the z-stack (`Linear`) and `Split` are declared (I see this wasn't clear in writing, apologies).
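A rough sketch of what such a backwards-compatible switch could look like, written against a toy model rather than chalk itself. Everything here is an assumption for illustration: the `connect` function, its `arrow_below` parameter, and the list-of-layers representation are hypothetical and are not chalk's actual signatures.

```python
from typing import List

# Toy diagram: layer names in painting order, index 0 is at the very back.
Layers = List[str]


def connect(diagram: Layers, src: str, dst: str, arrow_below: bool = False) -> Layers:
    """Hypothetical connect with an ordering switch (not chalk's API).

    arrow_below=False reproduces the current behaviour (diagram + arrow);
    arrow_below=True gives arrow + diagram.
    """
    for name in (src, dst):
        if name not in diagram:
            raise KeyError(f"no layer named {name!r}")
    arrow = [f"arrow({src}->{dst})"]
    return arrow + diagram if arrow_below else diagram + arrow


stack = ["linear_0", "linear_1", "split"]
default_order = connect(stack, "split", "linear_0")
below_order = connect(stack, "split", "linear_0", arrow_below=True)

print(default_order)
print(below_order)
```

Because the flag defaults to the existing order, callers that never pass it would see no change.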
I'm not sure if I follow your response:
That is, move this part ... in the head function
Both permalinks are same. Assuming you meant move this function to head. I see there's an opacity difference showing that the arrows' z-index is handled properly in your suggestion, unlike in the former code, and it is clearly the better way to render. Thanks for pointing that out. Let me know if I'm missing something.
Thanks for the clarification, Jerin! I understand your motivation for the flag, but while it solves the "split" to "linear" use case, it doesn't seem to solve the "linear" to "scaled dot product attention" example.
Namely, we cannot achieve something like this
with either the original connect, which would yield
or the reversed-order version, which would yield
Ideally, I would prefer a solution that would address these sorts of cases as well. I don't know, Sasha, if you have any opinion on this matter.
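The limitation can be illustrated with a toy model in plain Python (not chalk). It assumes three heads stacked back to front and treats each diagram as a flat list of layer names; the names `blocks_i` and `arrow_i` are made up. The arrow of the middle head should sit above its own blocks but below the next head's blocks, and neither appending nor prepending it to the finished stack can place it there, whereas drawing it inside each head before stacking does.

```python
from typing import List

# Toy diagram: layer names in painting order, index 0 is at the very back.
Layers = List[str]


def head(i: int) -> Layers:
    """One head whose arrow is drawn before the heads are stacked."""
    return [f"blocks_{i}", f"arrow_{i}"]


def interleaved(layers: Layers) -> bool:
    """True when arrow_1 sits above blocks_1 but below blocks_2."""
    return layers.index("blocks_1") < layers.index("arrow_1") < layers.index("blocks_2")


stack = [f"blocks_{i}" for i in range(3)]

appended = stack + ["arrow_1"]    # models diagram + arrow
prepended = ["arrow_1"] + stack   # models arrow + diagram
per_head = [layer for i in range(3) for layer in head(i)]

print(interleaved(appended), interleaved(prepended), interleaved(per_head))
# prints: False False True
```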
Otherwise, here is a hack that reverses the order of elements after a connect:
def connect_outside_reverse(*args, **kwargs):
    dia = connect_outside(*args, **kwargs)
    return dia.diagram2 + dia.diagram1
or, written as a combinator:
from chalk.core import Compose

def zswap(diagram):
    if isinstance(diagram, Compose):
        return diagram.diagram2 + diagram.diagram1
    else:
        return diagram

zswap(connect_outside(dia, f"bot {i}", f"top {i}"))
Both permalinks are same. Assuming you meant move this function to head.
Oops, yes, you are right! Here is the corresponding code that generated that figure.