scala/scala3

Context Lambdas change identity on each reference; no persistent identity


Compiler version

3.6.3

Minimized code

Ordinary lambdas are distinct values at runtime: you can trivially compare a lambda against itself, put it in collections, and so on.

val f: (x: Int) => Int = x => 3 + x

f == f // true
f // Playground$$$Lambda$15092/0x00007f26d3fc7000@be515d5
f // Playground$$$Lambda$15092/0x00007f26d3fc7000@be515d5

var fns = Set[(x: Int) => Int]()
for i <- 1 to 10 do
  fns = fns + f

fns.size // 1

The last part using a set is a good demonstration for verifying correct behavior. Sets by definition do not allow duplicates, so 1 is what we should expect to see. Anything else would be extremely counter-intuitive in a language where functions are values.

However, lambdas with context parameters behave very differently.

val f: (x: Int) ?=> Int = 3 + x

var fns = Set[(x: Int) ?=> Int]()
for i <- 1 to 10 do
  fns = fns + f

fns.size // 10

Adding the same context lambda to a set 10 times results in a set of 10 elements.
Comparing that lambda with itself also evaluates to false, which shows that Scala creates a new value at runtime on every reference to f:

val f: (x: Int) ?=> Int = 3 + x

var seq = Seq[(x: Int) ?=> Int]()

def first(seq: Seq[Any]) = seq(0)

seq = f +: seq
val a = first(seq)

seq = f +: seq
val b = first(seq)

a == b // false 

first(seq) == first(seq) // true

a // Playground$$$Lambda$14026/0x00007f26d3d09ab8@490ed173
b // Playground$$$Lambda$14197/0x00007f26d3d46000@15165343

Some explanation: referencing a context lambda and passing it around without invoking it is very difficult in Scala, because as soon as you refer to it, the compiler searches the current scope for the required given values in order to apply it. To work around that for this experiment, we put the context function into a sequence of context functions, pass that sequence to a function that takes Seq[Any], and retrieve index 0 from it, so Scala no longer knows it's a context function.
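A minimal sketch of the two ways a reference to a context function can be typed, assuming a plain Int context parameter. `ReferenceDemo`, `applyWith`, and `applyToGiven` are hypothetical names used only for illustration. When the expected type is itself a context function type, the reference is wrapped in a fresh closure instead of being applied; when a plain Int is expected, the compiler applies it and must find a given Int in scope.

```scala
object ReferenceDemo:
  // A context function value; `summon[Int]` reads its context parameter.
  val f: Int ?=> Int = summon[Int] + 3

  // The parameter type is a context function type, so the argument `f` is
  // not applied at the call site; it is wrapped in a closure roughly like
  // `(x: Int) ?=> f(using x)` and then applied here with an explicit argument.
  def applyWith(h: Int ?=> Int, n: Int): Int = h(using n)

  // The expected type is plain Int, so `f` is applied immediately and the
  // compiler must find a given Int in scope, here the `using` parameter `n`.
  def applyToGiven(using n: Int): Int = f

  def main(args: Array[String]): Unit =
    println(applyWith(f, 4))        // 7
    println(applyToGiven(using 10)) // 13
```

The wrapping in the first case is consistent with the fresh instances observed in the experiments above.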

Another experiment someone showed:

val f: Int ?=> Int = ???
def equal[T](a: T, b: T): Boolean = {
  a == b
}

equal[Int ?=> Int](f, f) // false

There is more discussion in a thread I started here: https://users.scala-lang.org/t/preventing-invocation-of-context-functions-referenced/10625

Why does Scala change the identifier for the context lambda each time? Is it a new wrapper? Can we avoid this?

This behavior is unintuitive, and it has brought the project where I ran into this bug to a complete stop with no obvious solution. My use case requires receiving a context lambda and checking whether it is the same context lambda I stored previously; a simple == check would do the job, but that appears to be impossible.

Not sure if this is a question or a complaint, but the relevant spec is
https://dotty.epfl.ch/docs/reference/contextual/context-functions-spec.html
which is linked from
https://dotty.epfl.ch/docs/reference/contextual/context-functions.html
(the link is broken from
https://www.scala-lang.org/files/archive/api/current/docs/docs/reference/contextual/context-functions.html).


According to the spec, it is explicitly a new instance.

The forum discussion also says it's like by-name params, so I think the analogy is useful.

But for context functions, the difference lies in syntax, and not in the type of the expression.
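To make the by-name analogy concrete, here is a small sketch (the names `ByNameAnalogy`, `twiceByName`, and `twiceCtx` are made up for illustration). Each reference to a by-name parameter re-evaluates its argument expression, and each reference to a context function parameter where a plain Int is expected applies it again, re-running the whole argument block that the compiler wrapped into the closure.

```scala
object ByNameAnalogy:
  // Counts how many times an argument expression has been evaluated.
  var evaluations = 0

  // Each reference to the by-name parameter `x` re-evaluates the argument.
  def twiceByName(x: => Int): Int = x + x

  // Each reference to `h` where a plain Int is expected applies `h` again,
  // using the anonymous given Int from the second parameter list.
  def twiceCtx(h: Int ?=> Int)(using Int): Int =
    val a: Int = h
    val b: Int = h
    a + b

  def main(args: Array[String]): Unit =
    println(twiceByName({ evaluations += 1; 5 })) // 10
    println(evaluations)                          // 2
    evaluations = 0
    println(twiceCtx({ evaluations += 1; summon[Int] })(using 5)) // 10
    println(evaluations)                                          // 2
```

In both cases the syntax at the call site looks like an ordinary value, but what is passed is a deferred computation.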

It's helpful to link to Scastie, because Scastie snippets don't always correspond to the overt code.

If the puzzler were "What does this print?", I would have gasped with the rest of the audience.

val f: (x: Int) ?=> Int = 3 + x

var fns = Set[(x: Int) ?=> Int]()

class C:
  def g =
    for i <- 1 to 10 do
      fns = fns + f
    println(fns.size)

object X:
  val f: (x: Int) ?=> Int = 3 + x
  var fns = Set[(x: Int) ?=> Int]()
  for i <- 1 to 10 do
    fns = fns + f
  println(fns.size)

@main def test =
  X
  C().g

Thinking of you, Puzzler Guy!

The answer is

➜  scala-cli run --server=false -S 3.6.4 i22767.scala
10
1

Sorry, could you elaborate on why X and C().g behave differently? Since both are context functions, I'm surprised that you get 1; I would expect 10 for both.

One implication of this design choice seems to be that it is impossible to later check whether a context lambda is the same one you stored previously. That is very strange, because you clearly can compare them when you don't have to reference them by name, as this snippet showed:

first(seq) == first(seq) // true

Can you think of any other way to make something like this code work?

object LambdaContainer:
  var fns = Seq[(x: Int) ?=> Int]()
  def contains(f: (x: Int) ?=> Int): Boolean =
    fns.contains(f)

LambdaContainer.fns = LambdaContainer.fns :+ f
LambdaContainer.contains(f) // desired: true
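One possible workaround, sketched here as an assumption rather than an official API: wrap the context function in an ordinary class once, and store and compare the wrapper instead. Because the wrapper's type is not a context function type, referencing it never triggers wrapping or application, so it keeps a single identity. `CtxFn`, `LambdaContainer.add`, and `HolderDemo` are hypothetical names.

```scala
// Hypothetical holder: a plain class, so it keeps reference equality and
// is never wrapped or applied when referenced.
final class CtxFn[A, B](val run: A ?=> B):
  // Apply the stored context function with an explicit argument.
  def apply(a: A): B = run(using a)

object LambdaContainer:
  private var fns = Vector.empty[CtxFn[Int, Int]]
  def add(f: CtxFn[Int, Int]): Unit = fns = fns :+ f
  // CtxFn does not override equals, so this is an identity check.
  def contains(f: CtxFn[Int, Int]): Boolean = fns.contains(f)

object HolderDemo:
  def main(args: Array[String]): Unit =
    // Create the wrapper once; the literal inside it is allocated once.
    val f = new CtxFn[Int, Int]((x: Int) ?=> 3 + x)

    var set = Set.empty[CtxFn[Int, Int]]
    for _ <- 1 to 10 do set = set + f
    println(set.size) // 1

    LambdaContainer.add(f)
    println(LambdaContainer.contains(f)) // true
    println(LambdaContainer.contains(new CtxFn[Int, Int]((x: Int) ?=> 3 + x))) // false
    println(f(4)) // 7
```

The trade-off is that callers must pass the `CtxFn` wrapper itself: wrapping the same raw context function a second time produces a new, unequal wrapper.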

The difference is that after lambdaLift, the context function in the local body of a method is static, but the same context function in the initializer of the module is not.

    private final <static> def g$$anonfun$1$$anonfun$1(using x: Int): Int =
      Int.unbox(Y.f().apply(Int.box(x)))
    private final def g$$anonfun$1(i: Int): Unit =
      this fns_=
        (Y.fns() +
          {
            closure(<empty>.this.g$$anonfun$1$$anonfun$1:
              scala.runtime.java8.JFunction1$mcII$sp)
          }
        ).asInstanceOf[scala.collection.immutable.Set]

versus

    private final def $init$$$anonfun$2$$anonfun$1(using x: Int): Int =
      Int.unbox(X.f().apply(Int.box(x)))
    private final def $init$$$anonfun$2(i: Int): Unit =
      this fns_=
        (X.fns() +
          {
            closure(this.$init$$$anonfun$2$$anonfun$1:
              scala.runtime.java8.JFunction1$mcII$sp)
          }
        ).asInstanceOf[scala.collection.immutable.Set]

for this source:

object X:
  val f: (x: Int) ?=> Int = 3 + x

  var fns = Set[(x: Int) ?=> Int]()

  for i <- 1 to 10 do
    fns = fns + f
  println(fns.size)

object Y:
  val f: (x: Int) ?=> Int = 3 + x

  var fns = Set[(x: Int) ?=> Int]()

  def g() =
    for i <- 1 to 10 do
      fns = fns + f
    println(fns.size)