共计 5762 个字符,预计需要花费 15 分钟才能阅读完成。
先回顾一下之前写到哪里了
- 介绍ps-lite的基本概念 https://www.deeplearn.me/4302.html
- 介绍ps-lite核心组成 postOffice https://www.deeplearn.me/4303.html
- 介绍ps-lite 通信模块van https://www.deeplearn.me/4306.html
- 介绍ps-lite 中介 customer https://www.deeplearn.me/4308.html
这篇文章主要讲一下server 和woker,在扒拉一下ps架构的一张图
一般意义上来说:
- server负责梯度和参数的更新
- woker端负责前向和后向的计算
这也是之前有customer出现的缘故,server和worker集中去计算,负责通信的任务就交给customer。在上一节讲customer在哪里被创建的时候就提到kvworker和kvserver,这里在着重讲一下吧!
在这之前还是要补充一点kvworker 和kvserver都继承 SimpleApp,那么SimpleApp 又是啥?
SimpleApp:KVServer和KVWorker的父类,它提供了简单的Request, Wait, Response,Process功能;KVServer和KVWorker分别根据自己的使命重写了这些功能;
kvwoker
构造函数
explicit KVWorker(int app_id, int customer_id) : SimpleApp() {
using namespace std::placeholders;
slicer_ = std::bind(&KVWorker<Val>::DefaultSlicer, this, _1, _2, _3);
obj_ = new Customer(app_id, customer_id, std::bind(&KVWorker<Val>::Process, this, _1));
}
这里关于构造函数的定义也在上一节提到了,此处略过哈!
PULL函数
从开始的图你也看到worker需要从server拉取参数数据,那么肯定需要pull。
int Pull(const std::vector<Key>& keys,
std::vector<Val>* vals,
std::vector<int>* lens = nullptr,
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
SArray<Key> skeys(keys);
int ts = AddPullCB(skeys, vals, lens, cmd, cb);
KVPairs<Val> kvs;
kvs.keys = skeys;
kvs.priority = priority;
Send(ts, false, true, cmd, kvs);
return ts;
}
这里面有两个需要关注的调用,AddPullCB 和 Send,依次来看下这两个函数的定义和功能
AddPullCB 是添加一个callback,这个callback等所有server返回结果之后在执行,可以认为是一个阻塞等操作。
int KVWorker<Val>::AddPullCB(
// C* vals和D* lens指向由调用者指定的结构体。
// 等所有server都返回后,从所有server拉来的数据
const SArray<Key>& keys, C* vals, D* lens, int cmd,
// Callback& cb代表在所有server回复后要执行的额外的回调
// 一般我们都是在pull后就立刻阻塞等待,所以cb一般为空
const Callback& cb) {
// ************** 创建request,返回的ts是该request_id
int ts = obj_->NewRequest(kServerGroup);
// ************** 添加callback,等所有server都回复后再执行
AddCallback(ts, [this, ts, keys, vals, lens, cb]() mutable {
......
// 容纳ts(即request_id)所接受数据的缓冲区
auto& kvs = recv_kvs_[ts];
......
// total_keys是根据kvs统计出来的接收到的key的总数
// keys是当初请求的所有keys,检查二者是否相等
......
CHECK_EQ(total_key, keys.size()) << "lost some servers?";
// ************** 将所有server返回的数据,合并,填充到用户指定的输出位置
// vals和lens都指向调用者传入的结构体
// p_vals和p_lens都是指向输出区的指针
Val* p_vals = vals->data();
......
p_lens = lens->data();
......
// 遍历从各台server接收到的内容,填充到输出区p_vals和p_lens
for (const auto& s : kvs) {
memcpy(p_vals, s.vals.data(), s.vals.size() * sizeof(Val));
p_vals += s.vals.size();
if (p_lens) {
memcpy(p_lens, s.lens.data(), s.lens.size() * sizeof(int));
p_lens += s.lens.size();
}
}
......
recv_kvs_.erase(ts);//清空本次请求的接收缓冲区
......
if (cb) cb();// 如果有额外的callback,执行之
});
return ts;
}
send的操作才是真正的去请求server,下面看下send的定义
void KVWorker<Val>::Send(int timestamp, bool push, bool pull, int cmd, const KVPairs<Val>& kvs) {
// ****************** 决定要向哪些server发送请求
SlicedKVs sliced;// 存储分配结果
slicer_(kvs, Postoffice::Get()->GetServerKeyRanges(), &sliced);
// ****************** 有些server不包含本次请求要求的keys,提前处理
int skipped = 0;// 本次请求不涉及的servers的总数
//这里调用first参数,需要去追溯一下SlicedKVs 的定义
// using SlicedKVs = std::vector<std::pair<bool, KVPairs<Val>>>;
// bool 参数决定是否需要去这个server节点拉取数据,不需要直接跳过
for (size_t i = 0; i < sliced.size(); ++i) {
if (!sliced[i].first) ++skipped;
}
// 内部不过是tracker_[timestamp].second += skipped
// 假设这些不涉及的servers已经返回了
obj_->AddResponse(timestamp, skipped);
......
// ****************** 向所有涉及到的server发送请求
for (size_t i = 0; i < sliced.size(); ++i) {
const auto& s = sliced[i];
if (!s.first) continue;//本次请求不需要访问的server节点直接跳过
Message msg;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = obj_->customer_id();
msg.meta.request = true;
msg.meta.push = push;
msg.meta.pull = pull;
msg.meta.head = cmd;
msg.meta.timestamp = timestamp;
msg.meta.recver = Postoffice::Get()->ServerRankToID(i);
msg.meta.priority = kvs.priority;
const auto& kvs = s.second;//分配到当前节点上的key-value pairs
if (kvs.keys.size()) {
msg.AddData(kvs.keys);
msg.AddData(kvs.vals);
if (kvs.lens.size()) {
msg.AddData(kvs.lens);
}
}
//通过van通信模块发送请求
Postoffice::Get()->van()->Send(msg);
}
}
至此再回去看pull 应该就差不多了,除了pull之外还有一个zpull ,全称是zero pull,说是实现了零拷贝,起到一个加速的作用,这里就不细看了。
PUSH
说完pull 就是push了,woker的push 就是要把梯度传给server,让server 去更新参数。
int ZPush(const SArray<Key>& keys,
const SArray<Val>& vals,
const SArray<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr,
int priority = 0) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
kvs.priority = priority;
// send 将这些梯度传递到指定的server上
Send(ts, true, false, cmd, kvs);
return ts;
}
同时也还有一个zpush,本质上实现的功能是一致的。
差不多 woker 就这些事情,接下来讲下server ,其实都差不多,因为只是各自干的事情内容又一点不一样而已。
kvserver
构造函数
explicit KVServer(int app_id) : SimpleApp() {
using namespace std::placeholders;
obj_ = new Customer(app_id, app_id, std::bind(&KVServer<Val>::Process, this, _1));
}
Server 主要是处理参数更新和数据查询
- 参数更新:根据梯度更新相应的神经网络参数
- 数据查询:worker需要拉取参数去执行前向传播
完成上述需求主要依靠两个函数
Process
这个主要是来处理woker push 过来的数据
template <typename Val>
void KVServer<Val>::Process(const Message& msg) {
if (msg.meta.simple_app) {
SimpleApp::Process(msg); return;
}
KVMeta meta;
meta.cmd = msg.meta.head;
meta.push = msg.meta.push;
meta.pull = msg.meta.pull;
meta.sender = msg.meta.sender;
meta.timestamp = msg.meta.timestamp;
meta.customer_id = msg.meta.customer_id;
//KVPairs 保存的就是传递的数据
KVPairs<Val> data;
int n = msg.data.size();
if (n) {
CHECK_GE(n, 2);
data.keys = msg.data[0];
data.vals = msg.data[1];
if (n > 2) {
CHECK_EQ(n, 3);
data.lens = msg.data[2];
CHECK_EQ(data.lens.size(), data.keys.size());
}
}
CHECK(request_handle_);
//这个request_handle_是用户自定义的处理逻辑函数,主要是梯度更新参数的规则等
request_handle_(meta, data, this);
}
这里给出test里面的一个实例
void StartServer() {
if (!IsServer()) return;
auto server = new KVServer<float>(0);
//这一步就是在设置 request_handle_
server->set_request_handle(KVServerDefaultHandle<float>());
RegisterExitCallback([server](){ delete server; });
}
Response
故名思义就是将数据回复给worker,好像没啥要讲的。。。
template <typename Val>
void KVServer<Val>::Response(const KVMeta& req, const KVPairs<Val>& res) {
//res里存储的就是worker需要数据,这里只是在包装以 Message 封装一下,最后在通过send回复给worker
Message msg;
msg.meta.app_id = obj_->app_id();
msg.meta.customer_id = req.customer_id;
msg.meta.request = false;
msg.meta.push = req.push;
msg.meta.pull = req.pull;
msg.meta.head = req.cmd;
msg.meta.timestamp = req.timestamp;
msg.meta.recver = req.sender;
if (res.keys.size()) {
msg.AddData(res.keys);
msg.AddData(res.vals);
if (res.lens.size()) {
msg.AddData(res.lens);
}
}
Postoffice::Get()->van()->Send(msg);
}