Improper asynchronous update in a3c
rahulptel opened this issue · 1 comment
I doubt that the asynchronous update made by the current A3C implementation adheres to what is suggested in the paper. Suppose the workers share the `shared_model`. Then each worker should:
- Copy the weights of the shared network into its `local_model`
- Run for `n` steps or until the end of the episode
- Calculate the gradient
- Pass the gradient of the `local_model` to the `shared_model`
- Update the `shared_model` and go to step 1
Thus, while the `local_model` is taking those `n` steps, its weights do not change.
However, in the current implementation we use the `shared_model` directly for taking those `n` steps. Hence, it may happen that some process P1 updates the weights of `shared_model` and thereby affects another process P2: P2 might have started its rollout with one weight configuration of `shared_model`, which is then modified before its `n` steps are completed.
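To make the concern concrete, here is a minimal, hypothetical sketch of the pattern I mean (the rollout loop, `env`, and the action selection are my own illustration, not the repo's actual code; I assume the model returns `(logits, value)` and the old gym step API):

```python
import torch
from torch.distributions import Categorical

def rollout_with_shared_model(shared_model, env, n_steps):
    """Act with shared_model directly: another process's optimizer.step()
    can change these weights in the middle of the n steps."""
    transitions = []
    state = env.reset()
    for _ in range(n_steps):
        state_t = torch.from_numpy(state).float().unsqueeze(0)
        logits, _value = shared_model(state_t)          # weights may have just been changed by another worker
        action = Categorical(logits=logits).sample().item()
        next_state, reward, done, _ = env.step(action)  # old gym step API assumed
        transitions.append((state, action, reward))
        state = env.reset() if done else next_state
    return transitions
```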
I think we can make the following change to the `train` method to avoid the above phenomenon:
```python
def train(model):
    # `model` is the shared_model; its parameters live in shared memory
    local_model = ActorCritic()
    # Create optimizer for the shared model, e.g. optim.Adam(model.parameters(), ...)
    # Create environment
    while True:
        # Step 1: copy the current shared weights into the local model
        local_model.load_state_dict(model.state_dict())
        # Take n steps (or run to the end of the episode) using local_model
        optimizer.zero_grad()
        # Calculate the loss and compute gradients on the local model
        loss_fn(local_model(data), labels).backward()
        # Pass the local gradients to the shared model
        for param, shared_param in zip(local_model.parameters(), model.parameters()):
            shared_param._grad = param.grad
        # Update the shared model and go back to step 1
        optimizer.step()
```
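A note on the hand-off: because the optimizer is constructed over the shared model's parameters, assigning each worker's `param.grad` to the corresponding `shared_param._grad` makes `optimizer.step()` apply the locally computed gradients directly to the shared weights, while the rollout itself only ever touches the frozen `local_model`.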
I am not very familiar with asynchronous model updates in PyTorch, but looking at the docs at https://pytorch.org/docs/stable/notes/multiprocessing.html#asynchronous-multiprocess-training-e-g-hogwild, I think we are using the `shared_model` all the time.
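For context, this is the hogwild-style launcher pattern from those docs, adapted to the names used above (`ActorCritic` and `train` as sketched; `num_workers` is an arbitrary value I picked for illustration):

```python
import torch.multiprocessing as mp

if __name__ == "__main__":
    num_workers = 4
    shared_model = ActorCritic()
    shared_model.share_memory()  # parameters move to shared memory; gradients are not shared

    processes = []
    for _ in range(num_workers):
        p = mp.Process(target=train, args=(shared_model,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```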
If you think what I say is correct, I can make a PR.
I think you have made another valid point.
As you noted, the weights of the `local_model` change in the current version of the implementation. They should remain unchanged until the local model has interacted with the environment for `n_steps`.
If you could please revise the issue and make a PR, I would be very thankful.
Thank you again.