summaryrefslogtreecommitdiff
path: root/caffe2/utils/zmq_helper.h
blob: bd45be9192dcad078b5c7248be3b853dd7f907eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#ifndef CAFFE2_UTILS_ZMQ_HELPER_H_
#define CAFFE2_UTILS_ZMQ_HELPER_H_

#include <zmq.h>

#include "caffe2/core/logging.h"

namespace caffe2 {

/**
 * Owns a ZeroMQ context: created in the constructor, destroyed in the
 * destructor. Copying and assignment are disabled.
 */
class ZmqContext {
 public:
  /**
   * Creates a new context, then sets ZMQ_IO_THREADS to `io_threads` and
   * ZMQ_MAX_SOCKETS to ZMQ_MAX_SOCKETS_DFLT.
   *
   * Any failure is reported through CAFFE_ENFORCE / CAFFE_ENFORCE_EQ.
   */
  explicit ZmqContext(int io_threads) : ptr_(zmq_ctx_new()) {
    CAFFE_ENFORCE(ptr_ != nullptr, "Failed to create zmq context.");
    int rc = zmq_ctx_set(ptr_, ZMQ_IO_THREADS, io_threads);
    if (rc == 0) {
      rc = zmq_ctx_set(ptr_, ZMQ_MAX_SOCKETS, ZMQ_MAX_SOCKETS_DFLT);
    }
    if (rc != 0) {
      // When the enforce below aborts construction, ~ZmqContext() is never
      // invoked for this object, so the context must be released here or it
      // would leak. The return code is ignored: we are already failing.
      (void)zmq_ctx_destroy(ptr_);
      ptr_ = nullptr;
    }
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /**
   * Destroys the context.
   *
   * NOTE(review): the enforce below sits inside a destructor; if it raises,
   * the exception cannot propagate out of an implicitly-noexcept destructor,
   * so a failing zmq_ctx_destroy() presumably ends the process.
   */
  ~ZmqContext() {
    int rc = zmq_ctx_destroy(ptr_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Raw context handle, for passing to zmq_* functions such as zmq_socket().
  void* ptr() { return ptr_; }

 private:
  // Handle returned by zmq_ctx_new(); owned by this object.
  void* ptr_;

  C10_DISABLE_COPY_AND_ASSIGN(ZmqContext);
};

/**
 * Keeps a zmq_msg_t initialized for as long as the object lives:
 * zmq_msg_init() runs on construction and zmq_msg_close() on destruction.
 * Copying and assignment are disabled.
 */
class ZmqMessage {
 public:
  /// Initializes the wrapped message; a non-zero return code is enforced
  /// against via CAFFE_ENFORCE_EQ.
  ZmqMessage() {
    const int rc = zmq_msg_init(&msg_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Closes the wrapped message; a non-zero return code is enforced against
  /// via CAFFE_ENFORCE_EQ.
  ~ZmqMessage() {
    const int rc = zmq_msg_close(&msg_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Size of the message content, as reported by zmq_msg_size().
  size_t size() {
    return zmq_msg_size(&msg_);
  }

  /// Pointer to the message content, as reported by zmq_msg_data().
  void* data() {
    return zmq_msg_data(&msg_);
  }

  /// Pointer to the underlying zmq_msg_t, for passing to zmq_msg_* functions.
  zmq_msg_t* msg() {
    return &msg_;
  }

 private:
  C10_DISABLE_COPY_AND_ASSIGN(ZmqMessage);

  // The wrapped message object, stored by value.
  zmq_msg_t msg_;
};

/**
 * A ZeroMQ socket bundled with its own private single-I/O-thread context.
 *
 * The socket is closed in the destructor; the context member is destroyed
 * afterwards. The class is not copyable because ZmqContext is not.
 */
class ZmqSocket {
 public:
  /**
   * Creates a socket of the given ZeroMQ socket type (passed straight to
   * zmq_socket()) on a fresh context with one I/O thread.
   */
  explicit ZmqSocket(int type)
      : context_(1), ptr_(zmq_socket(context_.ptr(), type)) {
    CAFFE_ENFORCE(ptr_ != nullptr, "Faild to create zmq socket.");
  }

  /// Closes the socket; a non-zero return code from zmq_close() is enforced
  /// against via CAFFE_ENFORCE_EQ.
  ~ZmqSocket() {
    int rc = zmq_close(ptr_);
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Binds the socket to endpoint `addr` (zmq_bind); enforces success.
  void Bind(const string& addr) {
    int rc = zmq_bind(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Undoes a previous Bind() on `addr` (zmq_unbind); enforces success.
  void Unbind(const string& addr) {
    int rc = zmq_unbind(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Connects the socket to endpoint `addr` (zmq_connect); enforces success.
  void Connect(const string& addr) {
    int rc = zmq_connect(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /// Undoes a previous Connect() on `addr` (zmq_disconnect); enforces success.
  void Disconnect(const string& addr) {
    int rc = zmq_disconnect(ptr_, addr.c_str());
    CAFFE_ENFORCE_EQ(rc, 0);
  }

  /**
   * Makes a single attempt to send `msg` with the given zmq_send() flags.
   *
   * @return the value returned by zmq_send() when it is non-negative, or 0
   *         when the call failed with EAGAIN or EINTR (nothing was sent and
   *         the caller may retry). Any other error is logged with LOG(FATAL).
   */
  int Send(const string& msg, int flags) {
    int nbytes = zmq_send(ptr_, msg.c_str(), msg.size(), flags);
    // zmq_send() signals failure with -1, so the success test has to be
    // ">= 0"; a plain truthiness test would treat -1 as a successful send and
    // make the error handling below unreachable.
    if (nbytes >= 0) {
      return nbytes;
    } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
      // Retryable conditions, handled the same way as in Recv().
      return 0;
    } else {
      LOG(FATAL) << "Cannot send zmq message. Error number: "
                      << zmq_errno();
      return 0;
    }
  }

  /**
   * Calls Send() repeatedly until it returns a non-zero value.
   *
   * Empty messages are rejected up front: Send() returning 0 could not be
   * told apart from a retryable failure and the loop would never finish.
   */
  int SendTillSuccess(const string& msg, int flags) {
    CAFFE_ENFORCE(msg.size(), "You cannot send an empty message.");
    int nbytes = 0;
    do {
      nbytes = Send(msg, flags);
    } while (nbytes == 0);
    return nbytes;
  }

  /**
   * Makes a single attempt to receive a message into `msg` (flags = 0).
   *
   * @return the value returned by zmq_msg_recv() when it is non-negative, or
   *         0 when the call failed with EAGAIN or EINTR. Any other error is
   *         logged with LOG(FATAL).
   */
  int Recv(ZmqMessage* msg) {
    int nbytes = zmq_msg_recv(msg->msg(), ptr_, 0);
    if (nbytes >= 0) {
      return nbytes;
    } else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
      return 0;
    } else {
      LOG(FATAL) << "Cannot receive zmq message. Error number: "
                      << zmq_errno();
      return 0;
    }
  }

  /**
   * Calls Recv() repeatedly until it returns a non-zero value.
   *
   * Note that a Recv() result of 0 (including a successfully received
   * zero-length message) keeps the loop going.
   */
  int RecvTillSuccess(ZmqMessage* msg) {
    int nbytes = 0;
    do {
      nbytes = Recv(msg);
    } while (nbytes == 0);
    return nbytes;
  }

 private:
  // Declared before ptr_ so the context exists when the socket is created in
  // the constructor's initializer list.
  ZmqContext context_;
  // Socket handle returned by zmq_socket(); closed in the destructor.
  void* ptr_;
};

}  // namespace caffe2


#endif  // CAFFE2_UTILS_ZMQ_HELPER_H_