代码疑惑
Closed this issue · 2 comments
`def get_user_seqs(data_file):
lines = open(data_file).readlines()
user_seq = []
item_set = set()
for line in lines:
user, items = line.strip().split(' ', 1)
items = items.split(' ')
items = [int(item) for item in items]
user_seq.append(items)
item_set = item_set | set(items)
max_item = max(item_set)
num_users = len(lines)
num_items = max_item + 2
valid_rating_matrix = generate_rating_matrix_valid(user_seq, num_users, num_items)
test_rating_matrix = generate_rating_matrix_test(user_seq, num_users, num_items)
return user_seq, max_item, valid_rating_matrix, test_rating_matrix
`
请问作者我在看代码时发现num_items = max_item +2,但是数据集中item的id是以1为开头的,我的理解这个num_items应该是与max_items相等的,但是代码却+2,是我忽略了什么吗?能否请作者解释一下,十分感谢
感谢你的关注和提问。这里的写法其实是为了在创建稀疏矩阵时产生bug。
以Beauty数据集为例,因为item的id是从1开始的,所以12101个item的id就是1到12101,但矩阵的索引是从0开始的,那么索引就会是0到12100,这样就会产生bug。基于此,你可以尝试把+2换成+1,程序是可以正常运行的,加别的大于1的数导致矩阵shape增加不会产生影响,如果去掉+2,就会产生ValueError: column index exceeds matrix dimensions的报错。另外,如果在处理数据集时,使item的id从0开始,那么这里就不需要+1或+2。
后面在初始化embdding时,num_items需要+1或+2也是同理。因为dataloader的getitem方法进行索引时,索引也会从0开始,如果在初始化之前不加的话,训练就会有类似的索引越界的报错。
如果还有问题,欢迎提出。
这个issue将会被关闭,如果你有其它问题,欢迎创建新的issue。