chapter8 relu2deriv
leeivan1007 opened this issue · 3 comments
I modified the code to isolate the relu in the forward and back propagation, and the performance looks better:
```python
for i in range(len(images)):
    layer_0 = images[i:i+1]
    layer_1 = np.dot(layer_0, weights_0_1)  # modify: keep the pre-activation
    layer_1_relu = relu(layer_1)            # modify: apply relu separately
    layer_2 = np.dot(layer_1_relu, weights_1_2)

    error += np.sum((labels[i:i+1] - layer_2) ** 2)
    correct_cnt += int(np.argmax(layer_2) == np.argmax(labels[i:i+1]))

    layer_2_delta = labels[i:i+1] - layer_2
    layer_1_delta = layer_2_delta.dot(weights_1_2.T) * relu2deriv(layer_1)

    weights_1_2 += alpha * layer_1_relu.T.dot(layer_2_delta)  # modify
    weights_0_1 += alpha * layer_0.T.dot(layer_1_delta)
```
(Screenshots of the training output before and after the change were attached here.)
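For comparison, the lines flagged `# modify` presumably replaced something like the following in the original chapter-8 loop. This is a reconstruction inferred from the markers above, not the book's verbatim code:

```python
# Assumed original (pre-modification) lines, inferred from the "# modify"
# markers above -- not quoted from the book.
layer_1 = relu(np.dot(layer_0, weights_0_1))  # relu fused into the forward pass
layer_2 = np.dot(layer_1, weights_1_2)

layer_1_delta = layer_2_delta.dot(weights_1_2.T) * relu2deriv(layer_1)  # layer_1 here is post-relu
weights_1_2 += alpha * layer_1.T.dot(layer_2_delta)
```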
The output of relu2deriv(layer_1) will be equal to relu2deriv(layer_1_relu) in your code: wherever layer_1 > 0, layer_1_relu > 0 as well, and wherever layer_1 <= 0, layer_1_relu == 0, so (with the strict-inequality relu2deriv) both inputs produce the same boolean output.

The objective of multiplying layer_1_delta by the relu derivative is to zero out the deltas of units whose inputs did not contribute to the delta at layer_2, which in this case occurs when the entries of layer_1 are < 0. So it should be immaterial whether you use layer_1 or layer_1_relu as the input to relu2deriv.
Check this out by just adding this line:

```python
assert (relu2deriv(layer_1) == relu2deriv(layer_1_relu)).all()
```
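A self-contained version of that check, as a minimal sketch assuming the chapter-6 definitions (strict inequality) discussed below:

```python
import numpy as np

relu = lambda x: (x > 0) * x  # chapter-6 style definitions
relu2deriv = lambda x: x > 0

layer_1 = np.array([[-2.0, 0.0, 3.0]])  # sample pre-activations
layer_1_relu = relu(layer_1)            # [[0., 0., 3.]]

# Zero and negative pre-activations both map to 0 under relu, and x > 0 is
# False in either case, so the two boolean masks agree elementwise.
assert (relu2deriv(layer_1) == relu2deriv(layer_1_relu)).all()
```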
I think there is an error in the relu2deriv function in chapter 8:

```python
relu2deriv = lambda x: x >= 0  # returns 1 for input > 0, return 0 otherwise
relu = lambda x: (x >= 0) * x  # returns x if x > 0, return 0 otherwise

relu2deriv(1)   # True = 1 - ok
relu(1)         # 1 - ok
relu(0)         # 0 - weight does not affect the output
relu2deriv(0)   # True = 1 - but used in backpropagation
relu2deriv(-1)  # False = 0 - ok
relu(-1)        # 0 - ok
```
The comment promises "returns 1 for x > 0, return 0 otherwise", but the function actually tests x >= 0, so relu2deriv(0) incorrectly returns 1.
The correct code is shown in chapter 6, "Backpropagation in Code":

```python
relu2deriv = lambda x: x > 0  # returns 1 for input > 0, return 0 otherwise
relu = lambda x: (x > 0) * x  # returns x if x > 0, return 0 otherwise

relu2deriv(1)   # True = 1 - ok
relu(1)         # 1 - ok
relu(0)         # 0 - ok
relu2deriv(0)   # False = 0 - ok
relu2deriv(-1)  # False = 0 - ok
relu(-1)        # 0 - ok
```
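To see why the `>=` matters in practice: in the chapter-8 loop, `relu2deriv` is applied to a `layer_1` that has already been through `relu` (which is exactly what the "isolate the relu" change in the first comment works around), so `x >= 0` is true for every entry and the derivative mask never zeroes anything. A minimal sketch of that effect, using the hypothetical names `relu2deriv_ch6` / `relu2deriv_ch8` for the two versions:

```python
import numpy as np

relu = lambda x: (x > 0) * x
relu2deriv_ch6 = lambda x: x > 0   # chapter-6 version (correct)
relu2deriv_ch8 = lambda x: x >= 0  # chapter-8 version under discussion

# When layer_1 is already the relu output, every entry is >= 0 ...
layer_1 = relu(np.array([[-2.0, 0.0, 3.0]]))  # [[0., 0., 3.]]

print(relu2deriv_ch8(layer_1))  # [[ True  True  True]] - dead units still pass gradient
print(relu2deriv_ch6(layer_1))  # [[False False  True]] - dead units correctly masked
```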