KingGugu/TiCoSeRec

代码疑惑

Closed this issue · 2 comments

`def get_user_seqs(data_file):
lines = open(data_file).readlines()
user_seq = []
item_set = set()
for line in lines:
user, items = line.strip().split(' ', 1)
items = items.split(' ')
items = [int(item) for item in items]
user_seq.append(items)
item_set = item_set | set(items)
max_item = max(item_set)

num_users = len(lines)
num_items = max_item + 2

valid_rating_matrix = generate_rating_matrix_valid(user_seq, num_users, num_items)
test_rating_matrix = generate_rating_matrix_test(user_seq, num_users, num_items)
return user_seq, max_item, valid_rating_matrix, test_rating_matrix

`
请问作者我在看代码时发现num_items = max_item +2,但是数据集中item的id是以1为开头的,我的理解这个num_items应该是与max_items相等的,但是代码却+2,是我忽略了什么吗?能否请作者解释一下,十分感谢

感谢你的关注和提问。这里的写法其实是为了在创建稀疏矩阵时产生bug。

以Beauty数据集为例,因为item的id是从1开始的,所以12101个item的id就是1到12101,但矩阵的索引是从0开始的,那么索引就会是0到12100,这样就会产生bug。基于此,你可以尝试把+2换成+1,程序是可以正常运行的,加别的大于1的数导致矩阵shape增加不会产生影响,如果去掉+2,就会产生ValueError: column index exceeds matrix dimensions的报错。另外,如果在处理数据集时,使item的id从0开始,那么这里就不需要+1或+2。

后面在初始化embdding时,num_items需要+1或+2也是同理。因为dataloader的getitem方法进行索引时,索引也会从0开始,如果在初始化之前不加的话,训练就会有类似的索引越界的报错。

如果还有问题,欢迎提出。

这个issue将会被关闭,如果你有其它问题,欢迎创建新的issue。