共计 862 个字符,预计需要花费 3 分钟才能阅读完成。
Revision history:
2017/3/22: fixed some bugs in the code.
Core code: https://github.com/zhusimaji/ml/blob/master/prank.py
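from tqdm import tqdm  # progress bar for the inner loop


def learn_to_rank(self):
    """Train PRank: a weight vector self.weight plus ordered thresholds self.br."""
    print('start to learn rank')
    new_label = [0 for x in range(self.rank_label)]
    tao = []
    self.weight = [0.0 for x in range(self.rank_cate)]
    for num in range(self.rank_iter):
        for i in tqdm(range(self.rank_num)):
            # Column 0 holds the true rank; features start at column 2
            sumwx = sum([self.weight[x] * self.source_data[i][x + 2]
                         for x in range(len(self.weight))])
            # Predict the rank: the first r with w.x - b_r < 0,
            # falling back to the highest rank if no threshold is exceeded
            predict_rank = self.rank_label - 1
            for r in range(self.rank_label):
                if sumwx - self.br[r] < 0:
                    predict_rank = r
                    break
            # Update only when the prediction is wrong
            if self.source_data[i][0] != predict_rank:
                # y_r = -1 if the true rank is at or below threshold r, else +1
                for r in range(self.rank_label):
                    if self.source_data[i][0] <= r:
                        new_label[r] = -1
                    else:
                        new_label[r] = 1
                # tau_r = y_r for every threshold on the wrong side of w.x
                tao = [new_label[x] if (sumwx - self.br[x]) * new_label[x] <= 0 else 0.0
                       for x in range(self.rank_label)]
                tao_sum = sum(tao)
                # w <- w + (sum of tau_r) * x
                new_weight = [self.weight[x] + tao_sum * self.source_data[i][x + 2]
                              for x in range(self.rank_cate)]
                self.weight = new_weight
                # b_r <- b_r - tau_r
                for r in range(self.rank_label):
                    self.br[r] = self.br[r] - tao[r]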
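For reference, below is a minimal, self-contained sketch of the same PRank (perceptron ranking) update outside the class, so it can be run without the data-loading code. It assumes ranks are integers 0..K-1, keeps K-1 finite thresholds and treats the last one as +infinity. The function names prank_predict and prank_train and the toy data are hypothetical and not taken from the linked repository. On a mistake it sets y_r = -1 when the true rank is at or below threshold r (else +1), collects tau_r = y_r for every threshold with (w.x - b_r) * y_r <= 0, then moves w by (sum of tau_r) * x and each b_r by -tau_r.

```python
from typing import List, Tuple


def prank_predict(w: List[float], b: List[float], x: List[float]) -> int:
    """Return the first rank r with w.x < b[r]; the top rank if there is none."""
    score = sum(wi * xi for wi, xi in zip(w, x))
    for r, threshold in enumerate(b):
        if score < threshold:
            return r
    return len(b)  # implicit last threshold b_{K-1} = +infinity


def prank_train(X: List[List[float]], y: List[int], num_ranks: int,
                epochs: int = 10) -> Tuple[List[float], List[float]]:
    """Online PRank training (sketch); ranks are 0..num_ranks-1."""
    w = [0.0] * len(X[0])
    b = [0.0] * (num_ranks - 1)  # finite thresholds b_0..b_{K-2}
    for _ in range(epochs):
        for x, label in zip(X, y):
            if prank_predict(w, b, x) == label:
                continue  # no update on a correct prediction
            score = sum(wi * xi for wi, xi in zip(w, x))
            # y_r = -1 if the true rank is at or below threshold r, else +1
            y_r = [-1 if label <= r else 1 for r in range(num_ranks - 1)]
            # tau_r = y_r for every threshold on the wrong side of (or on) w.x
            tau = [yr if (score - br) * yr <= 0 else 0
                   for yr, br in zip(y_r, b)]
            step = sum(tau)
            w = [wi + step * xi for wi, xi in zip(w, x)]  # w <- w + (sum tau) x
            b = [br - t for br, t in zip(b, tau)]          # b_r <- b_r - tau_r
    return w, b


if __name__ == "__main__":
    X = [[1.0], [2.0], [3.0]]  # hypothetical one-feature toy data
    y = [0, 1, 2]
    w, b = prank_train(X, y, num_ranks=3, epochs=20)
    print([prank_predict(w, b, x) for x in X])  # [0, 1, 2]
```

On this separable toy set the sketch stops making mistakes after a few epochs (ending with w = [1.0] and b = [2.0, 3.0]), after which further epochs change nothing because updates happen only on errors.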
正文完
请博主喝杯咖啡吧!