Question about how transitionCountMatrix is computed in graph embedding
vcicii opened this issue · 0 comments
vcicii commented
Source code:
from collections import defaultdict

def generateTransitionMatrix(samples):
    pairSamples = samples.flatMap(lambda x: generate_pair(x))
    pairCountMap = pairSamples.countByValue()
    pairTotalCount = 0
    transitionCountMatrix = defaultdict(dict)
    itemCountMap = defaultdict(int)
    for key, cnt in pairCountMap.items():
        key1, key2 = key
        # Should this be changed to += cnt?
        transitionCountMatrix[key1][key2] = cnt
        itemCountMap[key1] += cnt
        pairTotalCount += cnt
    ......
Proposed change:
    for key, cnt in pairCountMap.items():
        key1, key2 = key
        if key1 not in transitionCountMatrix or key2 not in transitionCountMatrix[key1]:
            transitionCountMatrix[key1][key2] = cnt
        else:
            transitionCountMatrix[key1][key2] += cnt
        itemCountMap[key1] += cnt
        pairTotalCount += cnt
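
One way to check whether the two versions can ever differ is to compare them on a small input. The sketch below is only an illustration and makes several assumptions. `collections.Counter` stands in for `countByValue()`, which in PySpark returns a dictionary with one entry per distinct value. The sample pairs and the two helper functions are hypothetical and not part of the project. If each `(key1, key2)` pair appears only once in `pairCountMap`, the `else` branch of the proposed change would never run, and assignment and accumulation should give the same matrix.

```python
from collections import Counter, defaultdict

# Hypothetical sample pairs. Counter is used here as a stand-in for
# RDD.countByValue(): it yields one entry per distinct pair.
pairs = [("a", "b"), ("a", "b"), ("a", "c"), ("b", "a")]
pair_count_map = Counter(pairs)


def build_with_assign(pair_counts):
    """Mirror the original code: plain assignment of each count."""
    matrix = defaultdict(dict)
    for (key1, key2), cnt in pair_counts.items():
        matrix[key1][key2] = cnt
    return matrix


def build_with_accumulate(pair_counts):
    """Mirror the proposed change: add to an existing count if present."""
    matrix = defaultdict(dict)
    for (key1, key2), cnt in pair_counts.items():
        matrix[key1][key2] = matrix[key1].get(key2, 0) + cnt
    return matrix


assigned = build_with_assign(pair_count_map)
accumulated = build_with_accumulate(pair_count_map)

# Both variants are compared on the same deduplicated input.
print(assigned == accumulated)
```

If the comparison holds on real data as well, the change is harmless but not required. It would only matter if `pairCountMap` were replaced by a structure that can contain the same pair more than once.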