loopholelabs/frisbee-go

Packet Decoder Breaks with Multiple Embedded Structs

Closed this issue · 0 comments

There is a very weird and very annoying bug in the generated RPC code that decodes encoded structs.

If you look at the following code:

func (x *APIKeyCertificateProvider) decode(d *packet.Decoder) error {
	if d.Nil() {
		return nil
	}
	var err error
	x.error, err = d.Error()
	if err != nil {
		x.ignore, err = d.Bool()
		if err != nil {
			return err
		}
		x.APIKey, err = d.String()
		if err != nil {
			return err
		}
		x.Owner, err = d.String()
		if err != nil {
			return err
		}
		if x.CertificateProvider == nil {
			x.CertificateProvider = NewCertificateProvider()
		}
		err = x.CertificateProvider.decode(d)
		if err != nil {
			return err
		}
		if x.ACMEUser == nil {
			x.ACMEUser = NewACMEUser()
		}
		err = x.ACMEUser.decode(d)
		if err != nil {
			return err
		}
	}
	d.Return()
	return nil
}```

You'll see that at one point we call `err = x.CertificateProvider.decode(d)`. Digging into that function, we find the following:
// decode populates x from d, reading the error, ignore, Identifier and
// Metadata fields in that fixed order.
//
// If d.Nil() reports true, decode returns nil without modifying x.
// NOTE(review): the regular fields are read only when d.Error() returns a
// non-nil err — presumably meaning no encoded error value was present;
// confirm against packet.Decoder.Error.
func (x *CertificateProvider) decode(d *packet.Decoder) error {
	if d.Nil() {
		return nil
	}
	var err error
	x.error, err = d.Error()
	if err != nil {
		x.ignore, err = d.Bool()
		if err != nil {
			return err
		}
		x.Identifier, err = d.String()
		if err != nil {
			return err
		}
		if x.Metadata == nil {
			x.Metadata = NewMetadata()
		}
		err = x.Metadata.decode(d)
		if err != nil {
			return err
		}
	}
	// BUG: this hands d back to the decoder pool (and nils its buffer) even
	// when decode was called by a parent decode that is still reading from d.
	d.Return()
	return nil
}

If you look carefully, you'll see that the CertificateProvider.decode function (the one right above) calls d.Return before returning.

d.Return, of course, returns the packet.Decoder to the decoder pool. It also resets the decoder's byte slice:

// ReturnDecoder clears d's byte slice and puts d back into decoderPool.
// A nil d is ignored. Callers must not keep using d after returning it.
func ReturnDecoder(d *Decoder) {
	if d == nil {
		return
	}
	d.b = nil
	decoderPool.Put(d)
}

// Return releases d back to the decoder pool; it is shorthand for
// ReturnDecoder(d). d must not be used after Return.
func (d *Decoder) Return() { ReturnDecoder(d) }

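So, when APIKeyCertificateProvider.decode runs, it calls CertificateProvider.decode, which calls d.Return before returning to APIKeyCertificateProvider.decode. But the parent function is not finished and still needs to use d. By that point, d's byte slice has been set to nil and d is back in the pool, where another caller may pick it up. And because every nested decode (and then the parent itself) calls d.Return, the same decoder can be put into the pool more than once. When the parent keeps reading from d, it occasionally fails.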

This is why the bug was not caught before: it only fails occasionally, and only when there are multiple embedded structs.

To fix this bug, I suggest modifying the `decode` template from:

{{/* decode emits the exported Decode method: it rejects a nil receiver with
     NilDecode, wraps b in a decoder from packet.GetDecoder, and delegates to
     the generated internal decode method. Nothing here returns the decoder
     to the pool; in this version the internal decode method does that. */}}
{{define "decode"}}
func (x *{{ CamelCase .FullName }}) Decode (b []byte) error {
    if x == nil {
        return NilDecode
    }
    d := packet.GetDecoder(b)
    return x.decode(d)
}
{{end}}

to something like:

{{/* decode emits the exported Decode method. The decoder obtained from
     packet.GetDecoder is released with a deferred d.Return(), so it goes
     back to the pool once, after x.decode and every nested decode call
     have finished. The internal decode methods must therefore not call
     d.Return() themselves. */}}
{{define "decode"}}
func (x *{{ CamelCase .FullName }}) Decode (b []byte) error {
    if x == nil {
        return NilDecode
    }
    d := packet.GetDecoder(b)
    defer d.Return()
    return x.decode(d)
}
{{end}}

We should also modify the `internalDecode` template from:

{{/* internalDecode emits the unexported decode(d) method. It reads the
     message from d in a fixed order: the error/ignore prefix, then the
     fields in $decoding.Other (bytes, enum and other single-value kinds),
     then $decoding.SliceFields, then $decoding.MessageFields (maps and
     nested messages, each decoded by its own decode method).
     BUG: the emitted method ends with d.Return(), so a nested message's
     decode releases the shared decoder while its parent is still reading
     from it. */}}
{{define "internalDecode"}}
func (x *{{CamelCase .FullName}}) decode(d *packet.Decoder) error {
    if d.Nil() {
        return nil
    }
    var err error
    x.error, err = d.Error()
    if err != nil {
        x.ignore, err = d.Bool()
        if err != nil {
            return err
        }
        {{ $decoding := GetDecodingFields .Fields -}}
        {{ range $field := $decoding.Other -}}
            {{ $decoder := GetLUTDecoder $field.Kind -}}
            {{ if eq $field.Kind 12 -}} {{/* protoreflect.BytesKind */ -}}
                x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}(nil)
            {{ else if eq $field.Kind 14 -}}  {{/* protoreflect.EnumKind */ -}}
                var {{ CamelCaseName $field.Name }}Temp uint32
                {{ CamelCaseName $field.Name }}Temp, err = d{{ $decoder }}()
                x.{{ CamelCaseName $field.Name }} = {{ FindValue $field }}({{ CamelCaseName $field.Name }}Temp)
            {{ else -}}
                x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}()
            {{end -}}
            if err != nil {
                return err
            }
        {{end -}}

        {{ if $decoding.SliceFields -}}
            var sliceSize uint32
        {{end -}}
        {{ range $field := $decoding.SliceFields -}}
        {{ $kind := GetKind $field.Kind -}}
        sliceSize, err = d.Slice({{ $kind }})
        if err != nil {
            return err
        }
        if uint32(len(x.{{ CamelCaseName $field.Name }})) != sliceSize {
            x.{{ CamelCaseName $field.Name }} = make({{ FindValue $field }}, sliceSize)
        }
        for i := uint32(0); i < sliceSize; i++ {
            {{ $decoder := GetLUTDecoder $field.Kind -}}
            {{ if eq $field.Kind 11 -}} {{/* protoreflect.MessageKind */ -}}
            err = x.{{ CamelCaseName $field.Name }}[i].decode(d)
            {{ else -}}
            x.{{ CamelCaseName $field.Name }}[i], err = d{{ $decoder }}()
            {{end -}}
            if err != nil {
                return err
            }
        }
        {{end -}}
        {{ range $field := $decoding.MessageFields -}}
            {{ if $field.IsMap -}}
                if !d.Nil() {
                {{ $keyKind := GetKind $field.MapKey.Kind -}}
                {{ $valKind := GetKind $field.MapValue.Kind -}}

                {{ CamelCaseName $field.Name }}Size, err := d.Map({{ $keyKind }}, {{ $valKind }})
                if err != nil {
                return err
                }
                x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.FullName }}Map({{ CamelCaseName $field.Name }}Size)
                err = x.{{ CamelCaseName $field.Name }}.decode(d, {{ CamelCaseName $field.Name }}Size)
                if err != nil {
                return err
                }
                }
            {{ else -}}
                if x.{{ CamelCaseName $field.Name }} == nil {
                x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.Message.FullName }}()
                }
                err = x.{{ CamelCaseName $field.Name }}.decode(d)
                if err != nil {
                return err
                }
            {{end -}}
        {{end -}}
    }
    d.Return()
    return nil
}
{{end}}

to something like the following (the key change is removing the `d.Return()` call before `return nil`):

{{/* internalDecode emits the unexported decode(d) method. It reads the
     message from d in a fixed order: the error/ignore prefix, then the
     fields in $decoding.Other (bytes, enum and other single-value kinds),
     then $decoding.SliceFields, then $decoding.MessageFields (maps and
     nested messages, each decoded by its own decode method).
     The emitted method only borrows d and never calls d.Return(). The
     exported Decode method, which obtained d from the pool, releases it
     once every nested decode has finished.
     NOTE(review): for repeated message fields, make() allocates the slice
     but not its elements. If FindValue yields a slice of pointers, the
     element's decode(d) call runs on a nil receiver. Confirm. */}}
{{define "internalDecode"}}
func (x *{{CamelCase .FullName}}) decode(d *packet.Decoder) error {
    if d.Nil() {
        return nil
    }
    var err error
    x.error, err = d.Error()
    if err != nil {
        x.ignore, err = d.Bool()
        if err != nil {
            return err
        }
        {{ $decoding := GetDecodingFields .Fields -}}
        {{ range $field := $decoding.Other -}}
            {{ $decoder := GetLUTDecoder $field.Kind -}}
            {{ if eq $field.Kind 12 -}} {{/* protoreflect.BytesKind */ -}}
                x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}(nil)
            {{ else if eq $field.Kind 14 -}}  {{/* protoreflect.EnumKind */ -}}
                var {{ CamelCaseName $field.Name }}Temp uint32
                {{ CamelCaseName $field.Name }}Temp, err = d{{ $decoder }}()
                x.{{ CamelCaseName $field.Name }} = {{ FindValue $field }}({{ CamelCaseName $field.Name }}Temp)
            {{ else -}}
                x.{{ CamelCaseName $field.Name }}, err = d{{ $decoder }}()
            {{end -}}
            if err != nil {
                return err
            }
        {{end -}}

        {{ if $decoding.SliceFields -}}
            var sliceSize uint32
        {{end -}}
        {{ range $field := $decoding.SliceFields -}}
        {{ $kind := GetKind $field.Kind -}}
        sliceSize, err = d.Slice({{ $kind }})
        if err != nil {
            return err
        }
        if uint32(len(x.{{ CamelCaseName $field.Name }})) != sliceSize {
            x.{{ CamelCaseName $field.Name }} = make({{ FindValue $field }}, sliceSize)
        }
        for i := uint32(0); i < sliceSize; i++ {
            {{ $decoder := GetLUTDecoder $field.Kind -}}
            {{ if eq $field.Kind 11 -}} {{/* protoreflect.MessageKind */ -}}
            err = x.{{ CamelCaseName $field.Name }}[i].decode(d)
            {{ else -}}
            x.{{ CamelCaseName $field.Name }}[i], err = d{{ $decoder }}()
            {{end -}}
            if err != nil {
                return err
            }
        }
        {{end -}}
        {{ range $field := $decoding.MessageFields -}}
            {{ if $field.IsMap -}}
                if !d.Nil() {
                {{ $keyKind := GetKind $field.MapKey.Kind -}}
                {{ $valKind := GetKind $field.MapValue.Kind -}}

                {{ CamelCaseName $field.Name }}Size, err := d.Map({{ $keyKind }}, {{ $valKind }})
                if err != nil {
                return err
                }
                x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.FullName }}Map({{ CamelCaseName $field.Name }}Size)
                err = x.{{ CamelCaseName $field.Name }}.decode(d, {{ CamelCaseName $field.Name }}Size)
                if err != nil {
                return err
                }
                }
            {{ else -}}
                if x.{{ CamelCaseName $field.Name }} == nil {
                x.{{ CamelCaseName $field.Name }} = New{{ CamelCase $field.Message.FullName }}()
                }
                err = x.{{ CamelCaseName $field.Name }}.decode(d)
                if err != nil {
                return err
                }
            {{end -}}
        {{end -}}
    }
    return nil
}
{{end}}

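This ensures that only the exported Decode function returns the decoder to the pool: exactly once, and only after all nested decoding has completed. (If any other nested helper also calls `d.Return()`, such as the map `decode(d, size)` methods called above, it needs the same change.)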