Look into mutual recursion?
bbarker opened this issue · 8 comments
Hi! Thanks for the great crate!
I was excited to see this went so well. In my benchmark PR, I outlined how mutual recursion doesn't really work. I may come back to look at why, but probably not anytime soon. I just wanted to open up the issue to discuss, or as a placeholder for later.
Thanks,
Thank you for your contributions!
Mutual recursion is an interesting problem. I'm not sure how to do it without some level of indirection in Rust.
For example:
#[tailcall]
fn is_even(x: u128) -> bool {
if x > 0 {
is_odd(x - 1)
} else {
true
}
}
#[tailcall]
fn is_odd(x: u128) -> bool {
if x > 0 {
is_even(x - 1)
} else {
false
}
}
...could expand to something like...
const IS_EVEN_IMPL: trampoline::StepFn<u128, bool> = trampoline::StepFn(&{
#[inline(always)]
|x| {
if x > 0 {
trampoline::Recurse(IS_ODD_IMPL, x - 1)
} else {
trampoline::Finish(true)
}
}
});
fn is_even(x: u128) -> bool {
IS_EVEN_IMPL.call(x)
}
const IS_ODD_IMPL: trampoline::StepFn<u128, bool> = trampoline::StepFn(&{
#[inline(always)]
|x| {
if x > 0 {
trampoline::Recurse(IS_EVEN_IMPL, x - 1)
} else {
trampoline::Finish(false)
}
}
});
fn is_odd(x: u128) -> bool {
IS_ODD_IMPL.call(x)
}
...with this support code...
mod trampoline {
pub enum Next<'a, Input, Output> {
Recurse(StepFn<'a, Input, Output>, Input),
Finish(Output),
}
pub use Next::*;
#[repr(transparent)]
pub struct StepFn<'a, Input, Output>(
pub &'a dyn Fn(Input) -> Next<'a, Input, Output>
);
impl<Input, Output> StepFn<'_, Input, Output> {
#[inline(always)]
pub fn call(self, input: Input) -> Output {
let mut next = Recurse(self, input);
loop {
match next {
Recurse(step_fn, input) => {
next = step_fn.0(input);
},
Finish(output) => {
break output;
},
}
}
}
}
}
This will only use a single stack frame, but it comes at the cost of passing function pointers around (which are likely opaque to most optimizations) and dispatching through vtables.
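For anyone who wants to try the idea without the const closures, here is a minimal self-contained sketch that uses plain `fn` pointers instead of `&dyn Fn`. The names `Next`, `run`, `is_even_step`, and `is_odd_step` are hypothetical and are not part of the crate's API; this only illustrates the loop-driven dispatch under those assumptions.

```rust
// Hypothetical sketch, not the tailcall crate's API.
// Each step either asks the driver loop to call another step function
// or reports the final result.
enum Next<I, O> {
    Recurse(fn(I) -> Next<I, O>, I),
    Finish(O),
}

// Driver loop: every "recursive call" becomes one loop iteration,
// so the stack depth stays constant no matter how large the input is.
fn run<I, O>(mut step: fn(I) -> Next<I, O>, mut input: I) -> O {
    loop {
        match step(input) {
            Next::Recurse(next_step, next_input) => {
                step = next_step;
                input = next_input;
            }
            Next::Finish(output) => return output,
        }
    }
}

fn is_even_step(x: u128) -> Next<u128, bool> {
    if x > 0 {
        Next::Recurse(is_odd_step, x - 1)
    } else {
        Next::Finish(true)
    }
}

fn is_odd_step(x: u128) -> Next<u128, bool> {
    if x > 0 {
        Next::Recurse(is_even_step, x - 1)
    } else {
        Next::Finish(false)
    }
}

pub fn is_even(x: u128) -> bool {
    run(is_even_step, x)
}

pub fn is_odd(x: u128) -> bool {
    run(is_odd_step, x)
}
```

Calling `is_even(1_000_000)` goes through a million steps in the loop rather than a million nested calls. The indirect call through the function pointer is still there, so the optimization concern above applies to this variant as well.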
I referenced the wrong issue in #12 🙈
First off thanks for making this library. I was thinking about making a runtime in Rust for a functional language. Your library could be pretty useful for that.
I was wondering though, how would you deal with mutually recursive functions with different types? Obviously the return types have to be the same but they could have different arity or types of arguments.
I think in this case you'll need to switch to an encoding where Next is generated per group of mutually recursive functions so that it has an alternative for each function in the group.
Like this:
pub enum Next<'a, Output> {
RecurseIsEven(StepFn<'a, u128, Output>, u128),
RecurseIsOdd(StepFn<'a, u128, Output>, u128),
Finish(Output),
}
However, now when we get to call we see that StepFn isn't necessary at all. Let's look at the is_even entry point for a second:
pub fn call(self, input: u128) -> Output {
let mut next = (IS_EVEN_IMPL)(input);
loop {
match next {
RecurseIsEven(input) => {
next = (IS_EVEN_IMPL)(input);
},
RecurseIsOdd(input) => {
next = (IS_ODD_IMPL)(input);
},
Finish(output) => {
break output;
},
}
}
}
Where the FOO_IMPLs are just something simple like this:
fn IS_EVEN_IMPL(x: u128) -> trampoline::Next {
if x > 0 {
trampoline::RecurseIsOdd(x - 1)
} else {
trampoline::Finish(true)
}
}
Because there will be a different implementation of Next and call for each block of mutually recursive definitions, we'll need to generate a new name somewhere. For the sake of argument, I've modeled that with an inner module with a unique name, say trampoline1234.
Putting it all together, it might look something like this:
fn is_even(x: u128) -> bool {
trampoline1234::RecurseIsEven(x).call()
}
fn is_odd(x: u128) -> bool {
trampoline1234::RecurseIsOdd(x).call()
}
// Inner module with unique name
mod trampoline1234 {
pub enum Next {
RecurseIsEven(u128),
RecurseIsOdd(u128),
Finish(bool),
}
pub use Next::*;
impl Next {
#[inline(always)]
pub fn call(self) -> bool {
let mut next = self;
loop {
match next {
RecurseIsEven(input) => {
next = IS_EVEN_IMPL(input);
},
RecurseIsOdd(input) => {
next = IS_ODD_IMPL(input);
},
Finish(output) => {
break output;
},
}
}
}
}
// I think it probably makes sense to put these in this inner module too to help with namespacing
// Each step returns Next (not bool) so that the loop in call can dispatch on it.
fn IS_EVEN_IMPL(x: u128) -> Next {
if x > 0 {
RecurseIsOdd(x - 1)
} else {
Finish(true)
}
}
fn IS_ODD_IMPL(x: u128) -> Next {
if x > 0 {
RecurseIsEven(x - 1)
} else {
// Zero is not odd
Finish(false)
}
}
}
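To connect this back to the question about different arities: since each function gets its own variant, the variants can simply carry different payloads. Below is a small sketch under that assumption. The module name `parity_group` and the two-argument `is_odd(x, at_zero)` are made up purely to show a group whose members take different argument lists; they are not from the crate.

```rust
// Hypothetical group: `is_even` takes one argument, `is_odd` takes two.
mod parity_group {
    pub enum Next {
        RecurseIsEven(u64),
        // The second field is the answer `is_odd` should give when it reaches zero.
        RecurseIsOdd(u64, bool),
        Finish(bool),
    }
    use Next::*;

    impl Next {
        // One loop drives the whole group; each variant unpacks its own arguments.
        pub fn call(self) -> bool {
            let mut next = self;
            loop {
                next = match next {
                    RecurseIsEven(x) => is_even_impl(x),
                    RecurseIsOdd(x, at_zero) => is_odd_impl(x, at_zero),
                    Finish(output) => break output,
                };
            }
        }
    }

    fn is_even_impl(x: u64) -> Next {
        if x > 0 {
            RecurseIsOdd(x - 1, false)
        } else {
            Finish(true)
        }
    }

    fn is_odd_impl(x: u64, at_zero: bool) -> Next {
        if x > 0 {
            RecurseIsEven(x - 1)
        } else {
            Finish(at_zero)
        }
    }
}

pub fn is_even(x: u64) -> bool {
    parity_group::Next::RecurseIsEven(x).call()
}

pub fn is_odd(x: u64, at_zero: bool) -> bool {
    parity_group::Next::RecurseIsOdd(x, at_zero).call()
}
```

The return type is still shared by the whole group, which matches the constraint mentioned in the question. Only the argument lists differ, and the enum absorbs that difference.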
I don't really know Rust's plugins/macros well enough to implement this, but that transformation should at least capture most cases of mutually recursive functions, although it probably wouldn't work with mutually recursive function pointers, closures, etc. You could also use this encoding for the existing tail-recursive case in trampoline. The existing encoding is probably slightly better for code size, but after optimizations I would imagine they're pretty similar. If code size were a concern, you could experiment with removing the #[inline(always)] attribute.
The other thing that may be difficult (or need different annotations) is detecting the groups of mutually recursive functions in the first place.
Great points, @dagit !
I think your approach of using a variant in place of a function for mutual recursion would work in general. The only issue I see is making it work with Rust macros. To me, it seems you need some facility to collect all of the functions which have been annotated with the #[tailcall] macro and transform them together, and AFAIK this is not possible.
Currently, proc macros just transform the tokens to which they are applied. Even if you were allowed to share state between invocations (something that is not officially supported), you don't have a good way to know when your "last" invocation will be to generate the shared code.
To make something like that work, I think you would need to lift the macro out to an item containing all of the functions (probably a module):
#![tailcall]
#[tailcall]
fn is_even(x: u128) -> bool {
if x > 0 {
is_odd(x - 1)
} else {
true
}
}
#[tailcall]
fn is_odd(x: u128) -> bool {
if x > 0 {
is_even(x - 1)
} else {
false
}
}
// ...or....
#[tailcall]
mod my_mod {
#[tailcall]
fn is_even(x: u128) -> bool {
if x > 0 {
is_odd(x - 1)
} else {
true
}
}
#[tailcall]
fn is_odd(x: u128) -> bool {
if x > 0 {
is_even(x - 1)
} else {
false
}
}
}
I think the final transform should be something like this.
Yeah, I don't have much to offer in terms of implementation strategies. In a lot of languages that allow mutual recursion, the user has to define the functions in a block. So maybe the example you gave with defining them in a module makes sense. I was trying to think if there is a way to use macros or something to make it clear what the grouping is about and it could create the module behind the scenes?
That final transformation version looks good to me.
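On the question of using a macro to make the grouping explicit and create the module behind the scenes, here is a rough `macro_rules!` sketch of what that surface could look like. Everything in it is an assumption for illustration: the macro name `trampoline_group!`, its syntax, and the restriction to single-argument functions are invented here. A declarative macro also cannot rewrite plain calls like `is_odd(x - 1)`, so in this sketch the bodies have to name the `Next` variants explicitly; a real implementation would presumably need a proc macro on the enclosing item for that part.

```rust
// Hypothetical macro: one invocation sees the whole group, so it can
// generate the shared enum, the driver loop, and the hidden module.
macro_rules! trampoline_group {
    (
        mod $group:ident -> $out:ty;
        $( fn $name:ident($arg:ident : $ty:ty) $body:block )+
    ) => {
        mod $group {
            // One variant per function in the group, plus the final result.
            #[allow(non_camel_case_types)]
            pub enum Next {
                $( $name($ty), )+
                Finish($out),
            }

            impl Next {
                pub fn call(self) -> $out {
                    let mut next = self;
                    loop {
                        next = match next {
                            $( Next::$name(input) => $name(input), )+
                            Next::Finish(output) => break output,
                        };
                    }
                }
            }

            // The step functions live inside the generated module.
            $( fn $name($arg: $ty) -> Next $body )+
        }
    };
}

trampoline_group! {
    mod parity -> bool;

    fn is_even(x: u128) {
        if x > 0 { Next::is_odd(x - 1) } else { Next::Finish(true) }
    }

    fn is_odd(x: u128) {
        if x > 0 { Next::is_even(x - 1) } else { Next::Finish(false) }
    }
}

pub fn is_even(x: u128) -> bool {
    parity::Next::is_even(x).call()
}

pub fn is_odd(x: u128) -> bool {
    parity::Next::is_odd(x).call()
}
```

The point of the sketch is only that a single invocation over the whole block sidesteps the "which invocation is last" problem, since the macro receives every function of the group at once.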
Thanks, I look forward to checking this out!