Ach113/federeco

MF and MLP forward in models.py, and Fed Avg function

Closed this issue · 2 comments

Hi, I found your code while trying to reproduce the results of the FedNCF paper.
Your report shows that you obtained similar results, but I have not managed to. In particular, I am stuck at around 0.3 loss, 0.43 hit rate, and 0.25 NDCG.
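For reference when comparing these numbers: NCF-style evaluation holds out one positive item per user, ranks it against a set of sampled negative items, and reports hit rate (HR) and NDCG at a cutoff K, averaged over users. The helper below is a minimal sketch of that per-user computation. The function name and the default `k=10` are assumptions for illustration; this repository's evaluation code may differ.

```python
import math
from typing import Dict, Tuple


def hit_ratio_and_ndcg(item_scores: Dict[int, float], target: int, k: int = 10) -> Tuple[float, float]:
    """HR@k and NDCG@k for one user whose single held-out item is `target`.

    `item_scores` maps each candidate item (the held-out item plus the
    sampled negatives) to the model's predicted score.
    """
    # rank candidates by descending score and keep the top k
    top_k = sorted(item_scores, key=item_scores.get, reverse=True)[:k]
    if target not in top_k:
        return 0.0, 0.0
    position = top_k.index(target)  # 0-based rank of the held-out item
    # with a single relevant item, NDCG reduces to 1 / log2(position + 2)
    return 1.0, math.log(2) / math.log(position + 2)
```

For example, if the held-out item is ranked second, HR@K is 1 and NDCG@K is log 2 / log 3 ≈ 0.63; the reported HR and NDCG are the means of these per-user values.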

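I noticed a possible problem in the `forward` function in models.py: both the MF branch and the MLP branch take their embeddings from the MF embedding layers. The MF item latent is looked up in `mf_embedding_user` instead of `mf_embedding_item`, and the MLP branch reads both of its latents from the MF tables instead of its own MLP embeddings.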

```python
# matrix factorization
mf_user_latent = torch.nn.Flatten()(self.mf_embedding_user(user_input))
mf_item_latent = torch.nn.Flatten()(self.mf_embedding_user(item_input))  # item IDs looked up in the *user* MF table
mf_vector = torch.mul(mf_user_latent, mf_item_latent)
# mlp
mlp_user_latent = torch.nn.Flatten()(self.mf_embedding_user(user_input))  # MLP branch reuses the MF user table
mlp_item_latent = torch.nn.Flatten()(self.mf_embedding_item(item_input))  # ...and the MF item table
mlp_vector = torch.cat([mlp_user_latent, mlp_item_latent], dim=1)
```
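A corrected forward pass might look like the sketch below, where each branch reads from its own pair of embedding tables. The class name, the MLP table names (`mlp_embedding_user`, `mlp_embedding_item`), the layer sizes, and the single hidden MLP layer are assumptions for illustration, not the repository's actual model.

```python
import torch
import torch.nn as nn


class NeuMFSketch(nn.Module):
    """Hypothetical NeuMF-style model with separate MF and MLP embedding tables."""

    def __init__(self, num_users: int, num_items: int, mf_dim: int = 8, mlp_dim: int = 16):
        super().__init__()
        # MF (GMF) branch: one table for users, one for items
        self.mf_embedding_user = nn.Embedding(num_users, mf_dim)
        self.mf_embedding_item = nn.Embedding(num_items, mf_dim)
        # MLP branch: independent tables (hypothetical names)
        self.mlp_embedding_user = nn.Embedding(num_users, mlp_dim)
        self.mlp_embedding_item = nn.Embedding(num_items, mlp_dim)
        self.mlp = nn.Sequential(nn.Linear(2 * mlp_dim, mlp_dim), nn.ReLU())
        # final layer over the concatenated MF and MLP vectors
        self.predict = nn.Linear(mf_dim + mlp_dim, 1)
        # flattens (batch, 1, dim) to (batch, dim); a no-op for (batch, dim)
        self.flatten = nn.Flatten()

    def forward(self, user_input: torch.Tensor, item_input: torch.Tensor) -> torch.Tensor:
        # matrix factorization: user IDs -> user table, item IDs -> item table
        mf_user_latent = self.flatten(self.mf_embedding_user(user_input))
        mf_item_latent = self.flatten(self.mf_embedding_item(item_input))
        mf_vector = torch.mul(mf_user_latent, mf_item_latent)
        # mlp: uses the MLP tables, not the MF ones
        mlp_user_latent = self.flatten(self.mlp_embedding_user(user_input))
        mlp_item_latent = self.flatten(self.mlp_embedding_item(item_input))
        mlp_vector = self.mlp(torch.cat([mlp_user_latent, mlp_item_latent], dim=1))
        # combine both branches into an interaction probability
        return torch.sigmoid(self.predict(torch.cat([mf_vector, mlp_vector], dim=1)))
```

Calling `NeuMFSketch(10, 20)(torch.tensor([0, 1]), torch.tensor([5, 6]))` returns a `(2, 1)` tensor of scores; with separate tables, an item ID larger than the number of users no longer indexes into the user table.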

Furthermore, the federated averaging function also aggregates the weights of the user embedding layers, even though those are supposed to stay private to each client.

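```python
import collections
import copy
from typing import List


def federated_averaging(client_weights: List[collections.OrderedDict]) -> collections.OrderedDict:
    """
    calculates the average of client weights
    """
    keys = client_weights[0].keys()
    averages = copy.deepcopy(client_weights[0])

    # sum every parameter (user embeddings included) across clients
    for w in client_weights[1:]:
        for key in keys:
            averages[key] += w[key]

    # divide by the number of clients to get the element-wise mean
    for key in keys:
        averages[key] /= len(client_weights)
    return averages
```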
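One way to address the second point, sketched under assumptions: average only the shared parameters and let each client keep its own user embeddings. The function names are hypothetical, and the default rule `is_user_embedding` assumes user-embedding parameters have `embedding_user` in their `state_dict` keys, as `mf_embedding_user` does in the snippet above.

```python
import collections
import copy
from typing import Callable, List


def is_user_embedding(key: str) -> bool:
    # hypothetical naming rule, e.g. "mf_embedding_user.weight"
    return "embedding_user" in key


def federated_averaging_shared(
    client_weights: List[collections.OrderedDict],
    is_private: Callable[[str], bool] = is_user_embedding,
) -> collections.OrderedDict:
    """Average only the non-private entries of the client state dicts."""
    shared_keys = [k for k in client_weights[0].keys() if not is_private(k)]
    # copy the first client's shared entries so its weights are not modified
    averages = collections.OrderedDict(
        (k, copy.deepcopy(client_weights[0][k])) for k in shared_keys
    )
    for w in client_weights[1:]:
        for key in shared_keys:
            averages[key] += w[key]
    for key in shared_keys:
        averages[key] /= len(client_weights)
    return averages
```

A client would then apply the result with `model.load_state_dict(shared, strict=False)`; with `strict=False` PyTorch does not raise on the missing user-embedding keys, so the client's local user embeddings are left as they are.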

Can you address these issues for me?

(Sorry for any confusion; my English is not very good.)

Thanks, I will take a look at it.

I have fixed the embedding error, but the model's performance is unchanged: I also get the values you mentioned.

However, I get similar results when I run the original NCF repo, so I don't think the model is performing poorly.

As for aggregating the user embeddings, I am not sure it causes any privacy concerns. Once the weights are aggregated, it is virtually impossible to separate an individual client's data from the rest. I am also not sure how else the aggregation could be done.
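Whether averaging hides an individual client's contribution depends on how clients update the user table. As a worked example under assumptions not stated in the thread: each of the K clients holds exactly one user u_k (distinct across clients), all clients take part in round t, and local training leaves the rows of other users unchanged (e.g., plain SGD without weight decay, since those rows receive zero gradient). Let E^{(t)} be the global user-embedding table with row e_u^{(t)}, and let δ_k be client k's change to its own row.

```latex
% client k only changes the row of its own user u_k
E_k = E^{(t)} + \Delta_k, \qquad
(\Delta_k)_u = \begin{cases} \delta_k, & u = u_k \\ 0, & u \neq u_k \end{cases}

% federated averaging over K clients
E^{(t+1)} = \frac{1}{K}\sum_{k=1}^{K} E_k = E^{(t)} + \frac{1}{K}\sum_{k=1}^{K} \Delta_k

% row of user u_j: only client j contributes a non-zero change
e_{u_j}^{(t+1)} = e_{u_j}^{(t)} + \frac{1}{K}\,\delta_j
\quad\Longrightarrow\quad
\delta_j = K\left(e_{u_j}^{(t+1)} - e_{u_j}^{(t)}\right)
```

Under these assumptions, anyone who sees two consecutive global models can recover each client's user-row update exactly, and that update is also scaled down by 1/K. If clients share users, or local training also changes other rows, the updates are mixed and this exact recovery no longer holds.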