HuwCampbell/grenade

Using Recurrent and Concat together

cpennington opened this issue · 10 comments

I'm trying to use Recurrent and Concat together in the same network. In particular, I'm trying to run two LSTMs in parallel on different subsets of the input and then concatenate their outputs with Concat.

So far I have something like this:

```haskell
type R = Recurrent
type F = FeedForward

type ShapeInput = 'D1 164

type CropOpponent = Crop 0 0 0 55
type CropPlayer = Crop 0 55 0 164

type LearnPlayer = RecurrentNetwork
    '[ F Reshape
    , F CropPlayer
    , F Reshape
    , R (LSTM 109 20)
    ]
    '[ ShapeInput
    , D2 1 164
    , D2 1 109
    , D1 109
    , D1 20
    ]

type LearnOpponent = RecurrentNetwork
    '[ F Reshape
    , F CropOpponent
    , F Reshape
    , R (LSTM 55 10)
    ]
    '[ ShapeInput
    , D2 1 164
    , D2 1 55
    , D1 55
    , D1 10
    ]

type RecNet = Network
    '[ Concat
        ShapeInput
        LearnPlayer
        ShapeInput
        LearnOpponent
    ]
    '[ ShapeInput
    , 'D1 30
    ]

randomNet :: MonadRandom m => m RecNet
randomNet = randomNetwork
```

On compilation, the error I'm getting is:

```
LearningBot.hs:69:13: error:
    • Couldn't match type ‘'False’ with ‘'True’
        arising from a use of ‘randomNetwork’
    • In the expression: randomNetwork
      In an equation for ‘randomNet’: randomNet = randomNetwork
```

Is there any way to accomplish what I'm looking for with Grenade right now?

P.S. I also tried it this way and got the same error:

```haskell
type RecNet = RecurrentNetwork
    '[ F (
        Concat
            ShapeInput
            LearnPlayer
            ShapeInput
            LearnOpponent
        )
    ]
    '[ ShapeInput
    , 'D1 30
    ]

type RecInput = RecurrentInputs
    '[ F (
        Concat
            ShapeInput
            LearnPlayer
            ShapeInput
            LearnOpponent
        )
    ]

randomNet :: MonadRandom m => m (RecNet, RecInput)
randomNet = randomRecurrent
```

Great question.

First up, I don't think your Crop layers are correct: the 164 in CropPlayer doesn't look right. (The numbers are how many are taken off the left and right, not the resulting width.)

It seems I should add a version of Crop that works on 1D shapes; you can also do this in your own code if you like.
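To make the "amount taken off each side" convention concrete, here is a minimal standalone sketch of a 1D crop in plain Haskell. It does not use grenade: `V`, `crop1D`, `playerSlice` and `opponentSlice` are hypothetical names, `V` is a stand-in for a `'D1 n` shape, and a real grenade layer would additionally need instances of the library's layer classes.

```haskell
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Hypothetical stand-in for a 'D1 n shape: n values, with the width
-- tracked only in the type (not checked at runtime).
newtype V (n :: Nat) = V [Double]
  deriving Show

-- Take l elements off the left and r off the right, so an input of
-- width n yields width n - l - r. The numbers are amounts removed,
-- not a target width or an index range.
crop1D
  :: forall l r n m. (KnownNat l, KnownNat m, (n - l - r) ~ m)
  => V n -> V m
crop1D (V xs) = V (take width (drop left xs))
  where
    left = fromIntegral (natVal (Proxy :: Proxy l))
    width = fromIntegral (natVal (Proxy :: Proxy m))

-- Columns 55..163 of a 164-wide input: drop 55 on the left, 0 on the right.
playerSlice :: V 164 -> V 109
playerSlice = crop1D @55 @0

-- Columns 0..54: drop 0 on the left, 109 on the right.
opponentSlice :: V 164 -> V 55
opponentSlice = crop1D @0 @109
```

Because the output width is checked as `n - l - r`, a mistake like putting the desired width (164) in place of an amount to crop shows up as a type error rather than a silently wrong slice.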

I believe it's probably not possible right now*, because RecNet really needs to be a recurrent network with a recurrent Concat layer (R instead of F).

The problem is that Concat isn't an instance of RecurrentLayer. I can't think of a fundamental reason it shouldn't be, or at least why there couldn't be a layer just like Concat (say, RecConcat) that takes two layers tagged with F or R and makes a new recurrent layer.

Would you like to try writing it?

EDIT:

(*) That is, with the current Concat layer (and without an orphan instance). You can write your own layers downstream.
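One rough way to see why a recurrent concat seems possible is a toy model in plain Haskell (not grenade's API) where a recurrent layer is a step function over its sideways state. Concatenating two such layers just pairs their states, and an F-tagged layer is the special case with no state. `Rec`, `feedForward`, `concatRec`, `runSeq` and `runningTotal` are all hypothetical names for this sketch.

```haskell
import Data.List (mapAccumL)

-- Hypothetical toy model: a recurrent layer maps (sideways state, input)
-- to (new sideways state, output).
newtype Rec s i o = Rec (s -> i -> (s, o))

-- An F-tagged (feed-forward) layer is a recurrent layer with no state.
feedForward :: (i -> o) -> Rec () i o
feedForward f = Rec (\() x -> ((), f x))

-- Run two layers on the same input. The combined sideways state is the
-- pair of both states, and the two outputs are concatenated.
concatRec :: Rec s1 i [a] -> Rec s2 i [a] -> Rec (s1, s2) i [a]
concatRec (Rec f) (Rec g) = Rec $ \(s1, s2) x ->
  let (s1', y1) = f s1 x
      (s2', y2) = g s2 x
  in ((s1', s2'), y1 ++ y2)

-- Feed a sequence of inputs through a layer, threading the state along.
runSeq :: Rec s i o -> s -> [i] -> (s, [o])
runSeq (Rec f) = mapAccumL f

-- Example R-tagged layer: outputs the running total of everything seen.
runningTotal :: Rec Double [Double] [Double]
runningTotal = Rec (\s xs -> let s' = s + sum xs in (s', [s']))
```

For example, `runSeq (concatRec runningTotal (feedForward (take 1))) (0, ()) [[1, 2], [3, 4]]` threads the pair state through both steps and returns the final state `(10.0, ())` with outputs `[[3.0, 1.0], [10.0, 3.0]]`.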

Ah, good to know about how Crop works.

I'll take a stab at a 1D Crop and making Concat an instance of RecurrentLayer. Hopefully the types should guide me in the right direction (and I'll drop back here for advice if I get stuck).

Ahh, there's actually another problem. I haven't yet written an instance of RecurrentLayer for RecurrentNetwork.

I think it's possible, but it requires packing all the recurrent (sideways-travelling) shapes into a single vector.
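A toy illustration of that packing, assuming each LSTM's sideways state is a 1D vector (for the LSTM 109 20 and LSTM 55 10 in the original question that would mean 20 + 10 = 30 values). Two fixed-width states are concatenated into one vector and split back apart using the type-level width of the first. `V`, `pack` and `unpack` are hypothetical and not part of grenade.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Hypothetical stand-in for a 'D1 n sideways state.
newtype V (n :: Nat) = V [Double]
  deriving (Eq, Show)

-- Pack two sideways states into one flat vector whose width is the sum.
pack :: V a -> V b -> V (a + b)
pack (V xs) (V ys) = V (xs ++ ys)

-- Split a packed vector back apart. Only the first width is needed at
-- runtime; everything after it belongs to the second state.
unpack :: forall a b. KnownNat a => V (a + b) -> (V a, V b)
unpack (V zs) = (V xs, V ys)
  where
    (xs, ys) = splitAt (fromIntegral (natVal (Proxy :: Proxy a))) zs
```

The widths live in the types, so a packed `V (20 + 10)` can only be unpacked as a `(V 20, V 10)` pair.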

You might have to run both LSTM networks forwards individually for now. The GAN MNIST example shows a non-recurrent version of something like this.

Ah, ok. I had thought about doing that, but hadn't looked closely enough at runBackwards/runGradient to see that they spit out something input-shaped.

So the plan seems to be: runNetwork on both LSTMs, combine their outputs, and feed that into runNetwork for the combining network. Then take the target output and runBackwards through the combining network to get the gradients to pass back into the two LSTMs; after that, runGradient/applyUpdate for all the networks should do the trick. I'll give it a try and see how it works out.

Thanks for your help!

That's right. The only difference is that you'll need runRecurrent and backPropagateRecurrent for the LSTM nets.

Edit: sorry, runRecurrentForwards and runRecurrentBackwards would also be useful.
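The manual training step described above (run both branches, concatenate, run the combining net, back-propagate, split the input gradient, update everything) can be modelled with toy layers in plain Haskell. This is not grenade code: each "branch" here is just a single scale weight and the combining net is a dot product, so `Model`, `scaleFwd`, `scaleBwd`, `predict` and `trainStep` are hypothetical. In the real setup the branch steps would presumably be the recurrent functions mentioned in this thread, whose exact signatures are not reproduced here.

```haskell
-- Hypothetical toy model: each branch scales its input by one weight,
-- and the combining net is a dot product with a weight vector.
data Model = Model
  { wA :: Double    -- branch A weight (stand-in for LearnPlayer)
  , wB :: Double    -- branch B weight (stand-in for LearnOpponent)
  , wC :: [Double]  -- combining net weights
  } deriving Show

scaleFwd :: Double -> [Double] -> [Double]
scaleFwd w = map (w *)

-- Backward pass of a scale layer: (gradient w.r.t. w, gradient w.r.t. input).
scaleBwd :: Double -> [Double] -> [Double] -> (Double, [Double])
scaleBwd w xs dys = (sum (zipWith (*) xs dys), map (w *) dys)

predict :: Model -> ([Double], [Double]) -> Double
predict (Model a b c) (xa, xb) = sum (zipWith (*) c (scaleFwd a xa ++ scaleFwd b xb))

-- One SGD step on the loss 0.5 * (output - target)^2.
trainStep :: Double -> ([Double], [Double]) -> Double -> Model -> Model
trainStep rate (xa, xb) target (Model a b c) =
  Model (a - rate * dA) (b - rate * dB) (zipWith (\w g -> w - rate * g) c dC)
  where
    ya = scaleFwd a xa                   -- 1. run each branch forwards
    yb = scaleFwd b xb
    z = ya ++ yb                         -- 2. concatenate their outputs
    out = sum (zipWith (*) c z)          -- 3. run the combining net
    dOut = out - target                  -- loss gradient at the output
    dC = map (dOut *) z                  -- 4. combining net weight gradient
    dZ = map (dOut *) c                  --    and gradient w.r.t. its input
    (dYa, dYb) = splitAt (length ya) dZ  -- 5. split the input gradient
    (dA, _) = scaleBwd a xa dYa          -- 6. back-propagate each branch
    (dB, _) = scaleBwd b xb dYb          --    (their input gradients are unused)
```

The key step is 5: the combining net's input gradient has the same width as the concatenated outputs, so splitting it at the first branch's output width gives each branch exactly the gradient for its own output.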

Cool, I'm making progress on this. One question that came up as I was working is whether there's an easy way to construct an all-zero vector for a particular RecurrentInputs shape. I want to make sure my network always starts from the same state at the beginning of every game.

You can just use the literal 0.

S is an instance of Num, so it has fromInteger. In fact, RecurrentInputs xs is also an instance of Num, so 0 should work for the entire stack.

If you look at the code for backPropagateRecurrent, you can see I do this (for the back-propagated sideways gradients, at least).
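To see why the literal `0` is enough, here is a standalone model of the same idea: when `fromInteger` fills every element, `0` at a given type becomes an all-zero value of that shape, and a stack type whose `Num` instance delegates to its parts gets all-zero states for every layer at once. `V`, `Stack`, `:&` and `initialState` are hypothetical stand-ins for grenade's `S` and `RecurrentInputs`, not the library's actual definitions.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical width-indexed vector standing in for a 'D1 n shape.
newtype V (n :: Nat) = V [Double]
  deriving (Eq, Show)

-- fromInteger replicates the literal across the width the type asks for.
instance KnownNat n => Num (V n) where
  fromInteger k = V (replicate width (fromInteger k))
    where width = fromIntegral (natVal (Proxy :: Proxy n))
  V a + V b = V (zipWith (+) a b)
  V a * V b = V (zipWith (*) a b)
  negate (V a) = V (map negate a)
  abs (V a) = V (map abs a)
  signum (V a) = V (map signum a)

-- Hypothetical stack of per-layer states; Num works element-wise.
data Stack a b = a :& b
  deriving (Eq, Show)
infixr 5 :&

instance (Num a, Num b) => Num (Stack a b) where
  fromInteger k = fromInteger k :& fromInteger k
  (a :& b) + (c :& d) = (a + c) :& (b + d)
  (a :& b) * (c :& d) = (a * c) :& (b * d)
  negate (a :& b) = negate a :& negate b
  abs (a :& b) = abs a :& abs b
  signum (a :& b) = signum a :& signum b

-- All-zero sideways state for two recurrent layers of widths 20 and 10.
initialState :: Stack (V 20) (V 10)
initialState = 0
```

Since the type alone determines every width, resetting at the start of each game is just a matter of passing `0` again.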

Please see #32

In that branch, this compiles:

```haskell
type R = Recurrent
type F = FeedForward

type ShapeInput = 'D1 10

type LearnPlayer = RecurrentNetwork
    '[ R (LSTM 10 20) ]
    '[ ShapeInput, D1 20 ]

type LearnOpponent = RecurrentNetwork
    '[ R (LSTM 10 20) ]
    '[ ShapeInput, D1 20 ]

type RecNet = RecurrentNetwork
    '[ R (
        ConcatRecurrent
          (D1 20)
          (R LearnPlayer)
          (D1 20)
          (R LearnOpponent)
        )
    ]
    '[ ShapeInput, 'D1 40 ]

randomNet :: MonadRandom m => m RecNet
randomNet = randomRecurrent
```

I believe this is fixed, but feel free to follow up with any problems you're having.