共计 12731 个字符,预计需要花费 32 分钟才能阅读完成。
晚上也在思考这篇文章该以怎样的结构来撰写,ps-lite虽然是轻量版,但是很多东西都是在复用,东西都混在一起,所以在解释理清思路的时候有点绕。思来想去还是按照ps程序启动的逻辑来讲,涉及到本文需要讲的点在重点描写。
ps启动
启动一个完整的 ps 服务,需要启动scheduler、sever和worker,启动的时候都会产生一个实际的物理进程,这个进程里都会包含PostOffice,负责管理全局信息。
ps-lite给出一个简单的demo,启动的顺序是先启动schedule,然后是server 、worker
#!/bin/bash
# set -x
if [ # -lt 3 ]; then
echo "usage:0 num_servers num_workers bin [args..]"
exit -1;
fi
export DMLC_NUM_SERVER=1
shift
export DMLC_NUM_WORKER=1
shift
bin=1
shift
arg="@"
# start the scheduler
export DMLC_PS_ROOT_URI='127.0.0.1'
export DMLC_PS_ROOT_PORT=8000
export DMLC_ROLE='scheduler'
{bin}{arg} &
# start servers
export DMLC_ROLE='server'
for ((i=0; i<{DMLC_NUM_SERVER}; ++i)); do
export HEAPPROFILE=./S{i}
{bin}{arg} &
done
# start workers
export DMLC_ROLE='worker'
for ((i=0; i<{DMLC_NUM_WORKER}; ++i)); do
export HEAPPROFILE=./W{i}
{bin}{arg} &
done
从上面的启动程序可以看到,ps-lite启动任务一些参数是来自环境变量的,你会看到shell脚本里充斥着export。
那我们现在看看它是怎么启动程序的?看一个demo
#include <cmath>
#include "ps/ps.h"
using namespace ps;
void StartServer() {
if (!IsServer()) {
return;
}
auto server = new KVServer<float>(0);
server->set_request_handle(KVServerDefaultHandle<float>());
RegisterExitCallback([server](){ delete server; });
}
void RunWorker() {
if (!IsWorker()) return;
KVWorker<float> kv(0, 0);
// init
int num = 10000;
std::vector<Key> keys(num);
std::vector<float> vals(num);
int rank = MyRank();
srand(rank + 7);
for (int i = 0; i < num; ++i) {
keys[i] = kMaxKey / num * i + rank;
vals[i] = (rand() % 1000);
}
// push
int repeat = 50;
std::vector<int> ts;
for (int i = 0; i < repeat; ++i) {
ts.push_back(kv.Push(keys, vals));
// to avoid too frequency push, which leads huge memory usage
if (i > 10) kv.Wait(ts[ts.size()-10]);
}
for (int t : ts) kv.Wait(t);
// pull
std::vector<float> rets;
kv.Wait(kv.Pull(keys, &rets));
// pushpull
std::vector<float> outs;
for (int i = 0; i < repeat; ++i) {
// PushPull on the same keys should be called serially
kv.Wait(kv.PushPull(keys, vals, &outs));
}
float res = 0;
float res2 = 0;
for (int i = 0; i < num; ++i) {
res += std::fabs(rets[i] - vals[i] * repeat);
res2 += std::fabs(outs[i] - vals[i] * 2 * repeat);
}
CHECK_LT(res / repeat, 1e-5);
CHECK_LT(res2 / (2 * repeat), 1e-5);
LL << "error: " << res / repeat << ", " << res2 / (2 * repeat);
}
int main(int argc, char *argv[]) {
// start system ,这里会根据实际的角色名执行相应的逻辑,比如一开始执行是scheduler的启动任务,
// 会以scheduler的角色启动,后面的StartServer也是不会执行的,当然这里会在是server角色的时候还会在执行一次
Start(0);
// setup server nodes
StartServer();
// run worker nodes
RunWorker();
// stop system
Finalize(0, true);
return 0;
}
ps-lite在这里其实共用了一套代码,也就是说你启动scheduler、sever和worker 这些都会走一套启动代码,根据不同的角色名称去执行相应的代码逻辑,比如只有在角色scheduler的时候才会触发scheduler相关的代码逻辑。
PostOffice启动
接下来就先以scheduler 启动来介绍
Start(0);
//实际调用的方法
inline void Start(int customer_id, const char* argv0 = nullptr) {
Postoffice::Get()->Start(customer_id, argv0, true);
}
这里会有一个
Postoffice::Get()
调用Get方法是去获取PostOffice全局单例对象,这里想要强调的一点就是PostOffice是单例,即一个进程内只有这一个对象,全局变量。
无论scheduler还是worker 都会调用,那么你可以理解这里PostOffice单例是相对而言的,scheduler进程下有一个,sever进程下也有一个。
接下来再来看看 Start 函数做了哪些事情?
void Postoffice::Start(int customer_id, const char* argv0, const bool do_barrier) {
start_mu_.lock();
if (init_stage_ == 0) {
// 初始化环境变量,主要是从 shell 执行脚本中获取相关参数比如 role 角色变量、server 数量和worker 数量
InitEnvironment();
// init glog
if (argv0) {
dmlc::InitLogging(argv0);
} else {
dmlc::InitLogging("ps-lite\0");
}
// init node info.
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
for (int i = 0; i < num_servers_; ++i) {
int id = ServerRankToID(i);
for (int g : {id, kServerGroup, kWorkerGroup + kServerGroup,
kServerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
node_ids_[g].push_back(kScheduler);
}
init_stage_++;
}
start_mu_.unlock();
// start van
van_->Start(customer_id);
start_mu_.lock();
if (init_stage_ == 1) {
// record start time
start_time_ = time(NULL);
init_stage_++;
}
start_mu_.unlock();
// do a barrier here
if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
}
现在来看看这个初始化环境的函数做了哪些事情?
void Postoffice::InitEnvironment() {
const char* val = NULL;
std::string van_type = GetEnv("DMLC_PS_VAN_TYPE", "zmq");
// 核心的一个点就是创建了 Van,至于Van 是什么后续也会做相应的详细介绍
van_ = Van::Create(van_type);
//接下来都是在解析环境变量,对于我们而言就是判断这次启动的是哪个?
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_WORKER"));
num_workers_ = atoi(val);
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_NUM_SERVER"));
num_servers_ = atoi(val);
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
std::string role(val);
is_worker_ = role == "worker";
is_server_ = role == "server";
is_scheduler_ = role == "scheduler";
verbose_ = GetEnv("PS_VERBOSE", 0);
}
ok,让我们再回到PostOffice的start函数里,假设我们启动的是scheduler角色下的程序
for (int g : {kScheduler, kScheduler + kServerGroup + kWorkerGroup,
kScheduler + kWorkerGroup, kScheduler + kServerGroup}) {
node_ids_[g].push_back(kScheduler);
}
init_stage_++;
那这个node_ids_ 是干啥的?这里就需要讨论下 node_id 的概念
Node管理
其实是可以分为两个部分:node group 和 single node_id
首先我们介绍下 node id 映射功能,就是如何在逻辑节点和物理节点之间做映射,如何把物理节点划分成各个逻辑组,如何用简便的方法做到给组内物理节点统一发消息。
- 1,2,4分别标识Scheduler, ServerGroup, WorkerGroup。
- SingleWorker:rank * 2 + 9;SingleServer:rank * 2 + 8。
- 任意一组节点都可以用单个id标识,等于所有id之和。
概念
- Rank 是一个逻辑概念,是每一个节点(scheduler,work,server)内部的唯一逻辑标示。
- Node id 是物理节点的唯一标识,可以和一个 host + port 的二元组唯一对应。
- Node Group 是一个逻辑概念,每一个 group 可以包含多个 node id。ps-lite 一共有三组 group : scheduler 组,server 组,worker 组。
- Node group id 是 是节点组的唯一标示。
- ps-lite 使用 1,2,4 这三个数字分别标识 Scheduler,ServerGroup,WorkerGroup。每一个数字都代表着一组节点,等于所有该类型节点 id 之和。比如 2 就代表server 组,就是所有 server node 的组合。
- 为什么选择这三个数字?因为在二进制下这三个数值分别是 “001, 010, 100″,这样如果想给多个 group 发消息,直接把 几个 node group id 做 或操作 就行。
- 即 1-7 内任意一个数字都代表的是Scheduler / ServerGroup / WorkerGroup的某一种组合。
- 如果想把某一个请求发送给所有的 worker node,把请求目标节点 id 设置为 4 即可。
- 假设某一个 worker 希望向所有的 server 节点 和 scheduler 节点同时发送请求,则只要把请求目标节点的 id 设置为 3 即可,因为 3 = 2 + 1 = kServerGroup + kScheduler。
- 如果想给所有节点发送消息,则设置为 7 即可。
逻辑组的实现
三个逻辑组的定义如下:
/** \brief node ID for the scheduler */
static const int kScheduler = 1;
/**
* \brief the server node group ID
*
* group id can be combined:
* - kServerGroup + kScheduler means all server nodes and the scheuduler
* - kServerGroup + kWorkerGroup means all server and worker nodes
*/
static const int kServerGroup = 2;
/** \brief the worker node group ID */
static const int kWorkerGroup = 4;
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。
node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。
Rank vs node id
node id 是物理节点的唯一标示,rank 是每一个逻辑概念(scheduler,work,server)内部的唯一标示。这两个标示由一个算法来确定。
如下面代码所示,如果配置了 3 个worker,则 worker 的 rank 从 0 ~ 2,那么这几个 worker 实际对应的 物理 node ID 就会使用 WorkerRankToID 来计算出来。
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
具体计算规则如下:
/**
* \brief convert from a worker rank into a node id
* \param rank the worker rank
*/
static inline int WorkerRankToID(int rank) {
return rank * 2 + 9;
}
/**
* \brief convert from a server rank into a node id
* \param rank the server rank
*/
static inline int ServerRankToID(int rank) {
return rank * 2 + 8;
}
/**
* \brief convert from a node id into a server or worker rank
* \param id the node id
*/
static inline int IDtoRank(int id) {
#ifdef _MSC_VER
#undef max
#endif
return std::max((id - 8) / 2, 0);
}
- SingleWorker:rank * 2 + 9;
- SingleServer:rank * 2 + 8;
而且这个算法保证server id为偶数,node id为奇数。
这样我们可以知道,1-7 的id表示的是node group,单个节点的id 就从 8 开始。
具体计算规则如下:
Group vs node
因为有时请求要发送给多个节点,所以ps-lite用了一个 map 来存储每个 node group / single node 对应的实际的node节点集合,即 确定每个id值对应的节点id集。
std::unordered_map<int, std::vector<int>> node_ids_
for (int i = 0; i < num_workers_; ++i) {
int id = WorkerRankToID(i);
for (int g : {id, kWorkerGroup, kWorkerGroup + kServerGroup,
kWorkerGroup + kScheduler,
kWorkerGroup + kServerGroup + kScheduler}) {
node_ids_[g].push_back(id);
}
}
这 5 个id 相对应,即需要在 node_ids_ 这个映射表中对应的 4, 4 + 1, 4 + 2, 4 +1 + 2, 12 这五个 item 之中添加。就是上面代码中的内部 for 循环条件。即,node_ids_ [4], node_ids_ [5],node_ids_ [6],node_ids_ [7] ,node_ids_ [12] 之中,都需要把 12 添加到 vector 最后。
- 12(本身)
- 4(kWorkerGroup)
- 4+1(kWorkerGroup + kScheduler)
- 4+2(kWorkerGroup + kServerGroup)
- 4+1+2,(kWorkerGroup + kServerGroup + kScheduler )
所以,为了实现 “设置 1-7 内任意一个数字 可以发送给其对应的 所有node” 这个功能,对于每一个新节点,需要将其对应多个id(node,node group)上,这些id组就是本节点可以与之通讯的节点。例如对于 worker 2 来说,其 node id 是 2 * 2 + 8 = 12,所以需要将它与
- 1 ~ 7 的 id 表示的是 node group;
- 后续的 id(8,9,10,11 …)表示单个的 node。其中双数 8,10,12… 表示 worker 0, worker 1, worker 2,… 即(2n + 8),9,11,13,…,表示 server 0, server 1,server 2,…,即(2n + 9);
还是花了不少的功夫在讲解node,那么这个node 的标记是用来干啥的?
这些node的标记实际上与我们的worker还有server都是对应的关心,所以通过这些node标记就可以快速找打,这样通信同步一些数据就方便。
在记录完node_id之后,开始调用Van的启动程序。Van其实是一个通信模块。Van的东西还是蛮多的,打算放在下一篇文章里讲了。
在继续就是讲到 Barrier
if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
Barrier
同步
总的来讲,schedular节点通过计数的方式实现各个节点的同步。具体来说就是:
- 每个节点在自己指定的命令运行完后会向schedular节点发送一个Control::BARRIER命令的请求并自己阻塞直到收到schedular对应的返回后才解除阻塞;
- schedular节点收到请求后则会在本地计数,看收到的请求数是否和barrier_group的数量是否相等,相等则表示每个机器都运行完指定的命令了,此时schedular节点会向barrier_group的每个机器发送一个返回的信息,并解除其阻塞。
初始化
ps-lite 使用 Barrier 来控制系统的初始化,就是大家都准备好了再一起前进。这是一个可选项。具体如下:
- Scheduler等待所有的worker和server发送BARRIER信息;
- 在完成ADD_NODE后,各个节点会进入指定 group 的Barrier阻塞同步机制(发送 BARRIER 给 Scheduler),以保证上述过程每个节点都已经完成;
- 所有节点(worker和server,包括scheduler) 等待scheduler收到所有节点 BARRIER 信息后的应答;
- 最终所有节点收到scheduler 应答的Barrier message后退出阻塞状态;
等待 BARRIER 消息
Node会调用 Barrier 函数 告知Scheduler,随即自己进入等待状态。
注意,调用时候是
if (do_barrier) Barrier(customer_id, kWorkerGroup + kServerGroup + kScheduler);
复制代码
void Postoffice::Barrier(int customer_id, int node_group) {
if (GetNodeIDs(node_group).size() <= 1) return;
auto role = van_->my_node().role;
if (role == Node::SCHEDULER) {
CHECK(node_group & kScheduler);
} else if (role == Node::WORKER) {
CHECK(node_group & kWorkerGroup);
} else if (role == Node::SERVER) {
CHECK(node_group & kServerGroup);
}
std::unique_lock<std::mutex> ulk(barrier_mu_);
barrier_done_[0][customer_id] = false;
Message req;
req.meta.recver = kScheduler;
req.meta.request = true;
req.meta.control.cmd = Control::BARRIER;
req.meta.app_id = 0;
req.meta.customer_id = customer_id;
req.meta.control.barrier_group = node_group; // 记录了等待哪些
req.meta.timestamp = van_->GetTimestamp();
van_->Send(req); // 给 scheduler 发给 BARRIER
barrier_cond_.wait(ulk, [this, customer_id] { // 然后等待
return barrier_done_[0][customer_id];
});
}
这就是说,等待所有的 group,即 scheduler 节点也要给自己发送消息。
处理 BARRIER 消息
处理等待的动作在 Van 类之中,我们提前放出来。
具体ProcessBarrierCommand逻辑如下:
- 如果 msg->meta.request 为true,说明是 scheduler 收到消息进行处理。
- Scheduler会对Barrier请求进行增加计数。
- 当 Scheduler 收到最后一个请求时(计数等于此group节点总数),则将计数清零,发送结束Barrier的命令。这时候 meta.request 设置为 false;
- 向此group所有节点发送
request==false
的BARRIER
消息。
- 如果 msg->meta.request 为 false,说明是收到消息这个 respones,可以解除barrier了,于是进行处理,调用 Manage 函数 。
- Manage 函数 将app_id对应的所有costomer的
barrier_done_
置为true,然后通知所有等待条件变量barrier_cond_.notify_all()
。
- Manage 函数 将app_id对应的所有costomer的
void Van::ProcessBarrierCommand(Message* msg) {
auto& ctrl = msg->meta.control;
if (msg->meta.request) { // scheduler收到了消息,因为 Postoffice::Barrier函数 会在发送时候做设置为true。
if (barrier_count_.empty()) {
barrier_count_.resize(8, 0);
}
int group = ctrl.barrier_group;
++barrier_count_[group]; // Scheduler会对Barrier请求进行计数
if (barrier_count_[group] ==
static_cast<int>(Postoffice::Get()->GetNodeIDs(group).size())) { // 如果相等,说明已经收到了最后一个请求,所以发送解除 barrier 消息。
barrier_count_[group] = 0;
Message res;
res.meta.request = false; // 回复时候,这里就是false
res.meta.app_id = msg->meta.app_id;
res.meta.customer_id = msg->meta.customer_id;
res.meta.control.cmd = Control::BARRIER;
for (int r : Postoffice::Get()->GetNodeIDs(group)) {
int recver_id = r;
if (shared_node_mapping_.find(r) == shared_node_mapping_.end()) {
res.meta.recver = recver_id;
res.meta.timestamp = timestamp_++;
Send(res);
}
}
}
} else { // 说明这里收到了 barrier respones,可以解除 barrier了。具体见上面的设置为false处。
Postoffice::Get()->Manage(*msg);
}
}
Manage 函数就是解除了 barrier。
void Postoffice::Manage(const Message& recv) {
CHECK(!recv.meta.control.empty());
const auto& ctrl = recv.meta.control;
if (ctrl.cmd == Control::BARRIER && !recv.meta.request) {
barrier_mu_.lock();
auto size = barrier_done_[recv.meta.app_id].size();
for (size_t customer_id = 0; customer_id < size; customer_id++) {
barrier_done_[recv.meta.app_id][customer_id] = true;
}
barrier_mu_.unlock();
barrier_cond_.notify_all(); // 这里解除了barrier
}
}
在上面的启动程序中可能没见到下面两个函数的调用,但是这也是 Postoffice 重要的成员组成
数据key分布式存储
到现在为止,邮车和customer都有了,信件本身无非就是embedding这些参数,但是这些参数的存放也是有讲究的,这也是在上一篇文章中提到的分布式存储,这个分布式是如何体现的?
const std::vector<Range>& Postoffice::GetServerKeyRanges() {
server_key_ranges_mu_.lock();
//循环遍历所有的server,配置server key 的范围
//本质上就是根据server的数量均匀划分而已,就是这么简单
if (server_key_ranges_.empty()) {
for (int i = 0; i < num_servers_; ++i) {
server_key_ranges_.push_back(Range(
kMaxKey / num_servers_ * i,
kMaxKey / num_servers_ * (i+1)));
}
}
server_key_ranges_mu_.unlock();
return server_key_ranges_;
}
通过以上的操作的确解决了数据分布式存储,而且可以明确在worker向server端拉取数据的时候要去哪个server拉数据的问题。
用户管理
现在大概知道了邮车,那么怎么知道要给哪些customer送信件呢?邮局需要管理一份用户的名单。
Customer* Postoffice::GetCustomer(int app_id, int customer_id, int timeout) const {
Customer* obj = nullptr;
for (int i = 0; i < timeout * 1000 + 1; ++i) {
{
std::lock_guard<std::mutex> lk(mu_);
// app_id 是对应 kv存储的id,举个例子FM 里存在一阶weight app_id=0
// 通过app_id 去寻找customer,一般 worker 会有多个thread 对应不同的customer
//但是消费的都是同一个 kv,所以根据app_id可以找到对应的 customer
const auto it = customers_.find(app_id);
if (it != customers_.end()) {
std::unordered_map<int, Customer*> customers_in_app = it->second;
obj = customers_in_app[customer_id];
break;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
return obj;
}
这个 GetCustomer 的操作主要是在Van中的 ProcessDataMsg 调用,这里就是Van要把传递的信件交给customer,然后通过 GetCustomer 这个方式来获取相应的customer。
上面的函数列的是读取,还有 AddCustomer 和 RemoveCustomer 负责添加和删除。