Add proper per-file copyright notices/licenses and top-level license.
[bluesky.git] / kvstore / protobufrpc.cc
#include "protobufrpc.h"

#include <sys/time.h>

#include <cstring>

#include <google/protobuf/descriptor.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/coded_stream.h>
#include <boost/shared_ptr.hpp>
#include <boost/bind.hpp>

#include <boost/functional/hash.hpp>
11 using namespace std;
12
13 using namespace boost;
14 using asio::buffer;
15
16 namespace bicker
17 {
18
// Copies the in-memory bytes of `val` to the raw buffer at `data` and returns
// a pointer just past the written bytes.
//
// memcpy is used instead of a typed store (`*((T*)data) = val`) because
// `data` points into a byte buffer with no alignment guarantee: a typed store
// there is undefined behaviour and faults on strict-alignment CPUs.
// T must be trivially copyable.
template<typename T>
static void* void_write(void* data, T val)
{
    std::memcpy(data, &val, sizeof(T));
    return static_cast<char*>(data) + sizeof(T);
}
25
26 class ProtoBufRpcServiceRequest
27 {
28 public:
29     ProtoBufRpcServiceRequest(
30                             RpcController *ctrl,
31                             const MethodDescriptor* method,
32                             Message *request,
33                             Message *response,
34                             shared_ptr<ProtoBufRpcConnection> conn
35                            )
36         :_ctrl(ctrl),
37          _method(method),
38          _request(request),
39          _response(response),
40          _conn(conn)
41     {
42     }
43
44     ~ProtoBufRpcServiceRequest()
45     {
46
47     }
48
49     static void run(ProtoBufRpcServiceRequest *req)
50     {
51
52         req->_conn->writeResponse(req->_response.get());
53
54         delete req;
55     }
56
57     shared_ptr<RpcController> _ctrl;
58     const MethodDescriptor *_method;
59     shared_ptr<Message> _request;
60     shared_ptr<Message> _response;
61     shared_ptr<ProtoBufRpcConnection> _conn;
62 };
63
64 ProtoBufRpcConnection::ProtoBufRpcConnection(asio::io_service& io_service,
65                              Service *service)
66 :_socket(io_service),
67  _strand(io_service),
68  _service(service),
69  _state(STATE_NONE)
70 {
71 }
72
73 tcp::socket& ProtoBufRpcConnection::socket()
74 {
75     return _socket;
76 }
77
78 void ProtoBufRpcConnection::start()
79 {
80     _socket.async_read_some(_buffer.prepare(4096),
81             _strand.wrap(
82                 boost::bind(&ProtoBufRpcConnection::handle_read, shared_from_this(),
83                             asio::placeholders::error,
84                             asio::placeholders::bytes_transferred)));
85 }
86
87 void ProtoBufRpcConnection::writeResponse(Message *msg)
88 {
89     int rlen = msg->ByteSize();
90     int len = htonl(rlen);
91     int mlen = sizeof(len) + rlen;
92
93     void * data = asio::buffer_cast<void*>(_buffer.prepare(mlen));
94
95     data = void_write(data, len);
96
97     using google::protobuf::io::ArrayOutputStream;
98
99     ArrayOutputStream as(data, rlen);
100
101     msg->SerializeToZeroCopyStream(&as);
102
103     _buffer.commit(mlen);
104
105     asio::async_write(_socket,
106             _buffer.data(),
107             _strand.wrap(
108                 boost::bind(&ProtoBufRpcConnection::handle_write, 
109                             shared_from_this(),
110                 asio::placeholders::error,
111                 asio::placeholders::bytes_transferred)));
112 }
113
114
115 void ProtoBufRpcConnection::handle_read(const error_code& e, 
116                  std::size_t bytes_transferred)
117 {
118     if (!e)
119     {
120         _buffer.commit(bytes_transferred);
121
122         if (_state == STATE_NONE)
123         {
124             if (_buffer.size() >= sizeof(_id) + sizeof(_len))
125             {
126                 string b(
127                      buffers_begin(_buffer.data()),
128                      buffers_begin(_buffer.data())
129                                                 + sizeof(_id) + sizeof(_len)
130                         );
131
132                 _buffer.consume(sizeof(_id) + sizeof(_len));
133
134                 _id = *((int*)b.c_str());
135                 _id = ntohl(_id);
136
137                 _len = *((unsigned int*)(b.c_str() + sizeof(_id)));
138                 _len = ntohl(_len);
139
140                 _state = STATE_HAVE_ID_AND_LEN;
141             }
142             else
143             {
144                 start();
145             }
146         }
147
148         if (_state == STATE_HAVE_ID_AND_LEN || _state == STATE_WAITING_FOR_DATA)
149         {
150             if (_buffer.size() >= _len)
151             {
152                 const MethodDescriptor* method =
153                     _service->GetDescriptor()->method(_id);
154
155                 Message *req = _service->GetRequestPrototype(method).New();
156                 Message *resp = _service->GetResponsePrototype(method).New();
157
158                 using google::protobuf::io::ArrayInputStream;
159                 using google::protobuf::io::CodedInputStream;
160
161                 const void* data = asio::buffer_cast<const void*>(
162                                                         _buffer.data()
163                                                                  );
164                 ArrayInputStream as(data, _len);
165                 CodedInputStream is(&as);
166                 is.SetTotalBytesLimit(512 * 1024 * 1024, -1);
167
168                 if (!req->ParseFromCodedStream(&is))
169                 {
170                     throw std::runtime_error("ParseFromCodedStream");
171                 }
172
173                 _buffer.consume(_len);
174
175                 ProtoBufRpcController *ctrl = new ProtoBufRpcController();
176                 _service->CallMethod(method, 
177                                      ctrl,
178                                      req, 
179                                      resp, 
180                                      NewCallback(
181                                              &ProtoBufRpcServiceRequest::run,
182                                              new ProtoBufRpcServiceRequest(
183                                                            ctrl,
184                                                            method,
185                                                            req,
186                                                            resp,
187                                                            shared_from_this())
188                                                 )
189                                      );
190                 _state = STATE_NONE;
191             }
192             else
193             {
194                 _state = STATE_WAITING_FOR_DATA;
195                 start();
196             }
197         }
198
199     }
200     else
201     {
202         error_code ignored_ec;
203         _socket.shutdown(tcp::socket::shutdown_both, ignored_ec);
204     }
205 }
206
207 void ProtoBufRpcConnection::handle_write(const error_code& e,
208                                          std::size_t bytes_transferred)
209 {
210     if (e)
211     {
212         error_code ignored_ec;
213         _socket.shutdown(tcp::socket::shutdown_both, ignored_ec);
214     }
215     else
216     {
217         _buffer.consume(bytes_transferred);
218
219         if (_buffer.size())
220         {
221             asio::async_write(_socket,
222                     _buffer.data(),
223                     _strand.wrap(
224                         boost::bind(&ProtoBufRpcConnection::handle_write, 
225                                     shared_from_this(),
226                         asio::placeholders::error,
227                         asio::placeholders::bytes_transferred)));
228             return;
229         }
230
231         _state = STATE_NONE;
232         start();
233     }
234 }
235
236 ProtoBufRpcServer::ProtoBufRpcServer()
237     :_io_service(new asio::io_service())
238 {
239 }
240
241 bool ProtoBufRpcServer::registerService(uint16_t port,
242                                 shared_ptr<Service> service)
243 {
244     // This is not thread safe
245
246     // The RegisteredService Constructor fires up the appropriate
247     // async accepts for the service
248     _services.push_back(shared_ptr<RegisteredService>(
249                                         new RegisteredService(
250                                                     _io_service,
251                                                     port,
252                                                     service)));
253
254     return true;
255 }
256
257 void run_wrapper(asio::io_service *io_service)
258 {
259     struct itimerval itimer; 
260     setitimer(ITIMER_PROF, &itimer, NULL);
261
262     io_service->run();
263 }
264
265 void ProtoBufRpcServer::run()
266 {
267     try
268     {
269         if (_services.size() == 0)
270         {
271             throw std::runtime_error("No services registered for ProtoBufRpcServer");
272         }
273
274         size_t nprocs = sysconf(_SC_NPROCESSORS_ONLN);
275
276         vector<shared_ptr<thread> > threads;
277         for (size_t i = 0; i < nprocs; ++i)
278         {
279             shared_ptr<thread> t(new thread(
280                                     boost::bind(
281                                         //&run_wrapper,
282                                         &asio::io_service::run, 
283                                         _io_service.get())));
284             threads.push_back(t);
285         }
286
287         for (size_t i = 0; i < threads.size(); ++i)
288         {
289             threads[i]->join();
290         }
291     }
292     catch (std::exception &e)
293     {
294         std::cerr << "ProtoBufRpcService" << e.what() << std::endl;
295     }
296 }
297
298 void ProtoBufRpcServer::shutdown()
299 {
300     _io_service->stop();
301 }
302
303 ProtoBufRpcServer::RegisteredService::RegisteredService(
304                   shared_ptr<asio::io_service> io_service,
305                   uint16_t port,
306                   shared_ptr<Service> service
307                  )
308 :_io_service(io_service),
309  _port(port),
310  _service(service),
311  _endpoint(tcp::v4(), _port),
312  _acceptor(*_io_service),
313  _new_connection(new ProtoBufRpcConnection(*_io_service, _service.get()))
314 {
315     _acceptor.open(_endpoint.protocol());
316     _acceptor.set_option(tcp::acceptor::reuse_address(true));
317     _acceptor.bind(_endpoint);
318     _acceptor.listen();
319     _acceptor.async_accept(_new_connection->socket(),
320                    boost::bind(&ProtoBufRpcServer::RegisteredService::handle_accept, 
321                                this, 
322                                asio::placeholders::error));
323 }
324
325 void ProtoBufRpcServer::RegisteredService::handle_accept(const error_code& e)
326 {
327       if (!e)
328       {
329           _new_connection->start();
330           _new_connection.reset(new ProtoBufRpcConnection(*_io_service, _service.get()));
331           _acceptor.async_accept(_new_connection->socket(),
332                  boost::bind(&ProtoBufRpcServer::RegisteredService::handle_accept,
333                              this,
334                              asio::placeholders::error));
335       }
336
337 }
338
339 ProtoBufRpcController::ProtoBufRpcController()
340 {
341 }
342
343 ProtoBufRpcController::~ProtoBufRpcController()
344 {
345 }
346
347 void ProtoBufRpcController::Reset()
348 {
349 }
350
351 bool ProtoBufRpcController::Failed() const
352 {
353     return false;
354 }
355
356 string ProtoBufRpcController::ErrorText() const
357 {
358     return "No Error";
359 }
360
361 void ProtoBufRpcController::StartCancel()
362 {
363 }
364
365 void ProtoBufRpcController::SetFailed(const string &/*reason*/)
366 {
367 }
368
369 bool ProtoBufRpcController::IsCanceled() const
370 {
371     return false;
372 }
373
374 void ProtoBufRpcController::NotifyOnCancel(Closure * /*callback*/)
375 {
376 }
377
378 class ProtoBufRpcChannel::MethodHandler 
379     : public enable_shared_from_this<MethodHandler>,
380       private boost::noncopyable
381 {
382 public:
383     MethodHandler(auto_ptr<SocketCheckout> socket,
384                            const MethodDescriptor * method,
385                            RpcController * controller,
386                            const Message * request,
387                            Message * response,
388                            Closure * done
389                           )
390         :_socket(socket),
391         _method(method),
392         _controller(controller),
393         _request(request),
394         _response(response),
395         _done(done)
396     {
397     }
398
399     ~MethodHandler()
400     {
401         _socket.reset();
402         _done->Run();
403     }
404
405     static void execute(shared_ptr<MethodHandler> this_ptr)
406     {
407         int index = htonl(this_ptr->_method->index());
408         int rlen = this_ptr->_request->ByteSize();
409         int len = htonl(rlen);
410
411         int mlen = sizeof(index) + sizeof(len) + rlen;
412
413         void * data = asio::buffer_cast<void*>(this_ptr->_buffer.prepare(mlen));
414
415         data = void_write(data, index);
416         data = void_write(data, len);
417
418         using google::protobuf::io::ArrayOutputStream;
419
420         ArrayOutputStream as(data, rlen);
421
422         this_ptr->_request->SerializeToZeroCopyStream(&as);
423         this_ptr->_buffer.commit(mlen);
424
425         (*(this_ptr->_socket))->async_send(this_ptr->_buffer.data(),
426                                boost::bind(&ProtoBufRpcChannel::MethodHandler::handle_write,
427                                            this_ptr,
428                                            asio::placeholders::error,
429                                            asio::placeholders::bytes_transferred));
430     }
431
432     static void handle_write(shared_ptr<MethodHandler> this_ptr,
433                       const error_code& e, 
434                       std::size_t bytes_transferred)
435     {
436         if (!e)
437         {
438             this_ptr->_buffer.consume(bytes_transferred);
439
440             if (this_ptr->_buffer.size())
441             {
442                 (*(this_ptr->_socket))->async_send(this_ptr->_buffer.data(),
443                                        boost::bind(&ProtoBufRpcChannel::MethodHandler::handle_write,
444                                                    this_ptr,
445                                                    asio::placeholders::error,
446                                                    asio::placeholders::bytes_transferred));
447                 return;
448             }
449
450             (*(this_ptr->_socket))->async_receive(
451                                       buffer(&this_ptr->_len, sizeof(this_ptr->_len)),
452                                       boost::bind(
453                                                   &ProtoBufRpcChannel::MethodHandler::handle_read_len,
454                                                   this_ptr,
455                                                   asio::placeholders::error,
456                                                   asio::placeholders::bytes_transferred)
457                                      );
458         }
459         else
460         {
461             this_ptr->_controller->SetFailed(e.message());
462             (*(this_ptr->_socket))->close();
463         }
464     }
465
466     static void handle_read_len(shared_ptr<MethodHandler> this_ptr,
467                                 const error_code& e,
468                                 std::size_t bytes_transferred)
469     {
470         if (!e && bytes_transferred == sizeof(this_ptr->_len))
471         {
472             this_ptr->_len = ntohl(this_ptr->_len);
473             (*(this_ptr->_socket))->async_receive(
474                                       this_ptr->_buffer.prepare(this_ptr->_len),
475                                       boost::bind(
476                                                   &ProtoBufRpcChannel::MethodHandler::handle_read_response,
477                                                   this_ptr,
478                                                   asio::placeholders::error,
479                                                   asio::placeholders::bytes_transferred
480                                                  )
481                                      );
482         }
483         else
484         {
485             this_ptr->_controller->SetFailed(e.message());
486             (*(this_ptr->_socket))->close();
487         }
488     }
489
490     static void handle_read_response(shared_ptr<MethodHandler> this_ptr,
491                               const error_code& e, 
492                               std::size_t bytes_transferred)
493     {
494         if (!e)
495         {
496             this_ptr->_buffer.commit(bytes_transferred);
497             if (this_ptr->_buffer.size() >= this_ptr->_len)
498             {
499                 using google::protobuf::io::ArrayInputStream;
500                 using google::protobuf::io::CodedInputStream;
501
502                 const void* data = asio::buffer_cast<const void*>(
503                                                         this_ptr->_buffer.data()
504                                                                  );
505                 ArrayInputStream as(data, this_ptr->_len);
506                 CodedInputStream is(&as);
507                 is.SetTotalBytesLimit(512 * 1024 * 1024, -1);
508
509                 if (!this_ptr->_response->ParseFromCodedStream(&is))
510                 {
511                     throw std::runtime_error("ParseFromCodedStream");
512                 }
513
514                 this_ptr->_buffer.consume(this_ptr->_len);
515             }
516             else
517             {
518                 (*(this_ptr->_socket))->async_receive(
519                                           this_ptr->_buffer.prepare(this_ptr->_len - this_ptr->_buffer.size()),
520                                           boost::bind(
521                                                       &ProtoBufRpcChannel::MethodHandler::handle_read_response,
522                                                       this_ptr,
523                                                       asio::placeholders::error,
524                                                       asio::placeholders::bytes_transferred
525                                                      )
526                                          );
527                 return;
528             }
529         }
530         else
531         {
532             this_ptr->_controller->SetFailed(e.message());
533             (*(this_ptr->_socket))->close();
534         }
535     }
536
537 private:
538     auto_ptr<SocketCheckout> _socket;
539     const MethodDescriptor * _method;
540     RpcController * _controller;
541     const Message * _request;
542     Message * _response;
543     Closure * _done;
544     asio::streambuf _buffer;
545     unsigned int _len;
546     bool _status;
547     unsigned int _sent;
548 };
549
550
551 ProtoBufRpcChannel::ProtoBufRpcChannel(const string &remotehost, 
552                                   const string &port)
553     :_remote_host(remotehost), _port(port),
554      _resolver(_io_service),
555      _acceptor(_io_service),
556      _pool(2000, _io_service),
557      _lame_socket(_io_service),
558      _thread()
559 //                                &asio::io_service::run, 
560 //                                &_io_service)))
561 {
562
563     tcp::resolver::query query(_remote_host, _port);
564     tcp::resolver::iterator endpoint_iterator = _resolver.resolve(query);
565     tcp::resolver::iterator end;
566
567     error_code error = asio::error::host_not_found;
568
569     if (endpoint_iterator == end) throw syserr::system_error(error);
570
571     _pool.setEndpoint(*endpoint_iterator);
572
573     tcp::endpoint e(tcp::v4(), 0);
574     _acceptor.open(e.protocol());
575     _acceptor.set_option(tcp::acceptor::reuse_address(true));
576     _acceptor.bind(e);
577     _acceptor.listen();
578     _acceptor.async_accept(_lame_socket,
579                        boost::bind(&ProtoBufRpcChannel::lame_handle_accept, this, 
580                                    asio::placeholders::error));
581
582     _thread = shared_ptr<thread>(new thread(
583                                      boost::bind(
584                                              &asio::io_service::run, 
585                                              &_io_service)));
586 }
587
588 void ProtoBufRpcChannel::lame_handle_accept(const error_code &err)
589 {
590     if (!err)
591     {
592         _acceptor.async_accept(_lame_socket,
593                            boost::bind(&ProtoBufRpcChannel::lame_handle_accept,
594                                        this,
595                                        asio::placeholders::error));
596     }
597 }
598
599 ProtoBufRpcChannel::~ProtoBufRpcChannel()
600 {
601     _pool.cancelAndClear();
602
603     _io_service.stop();
604
605     _thread->join();
606 }
607
608 void ProtoBufRpcChannel::CallMethod(
609         const MethodDescriptor * method,
610         RpcController * controller,
611         const Message * request,
612         Message * response,
613         Closure * done)
614 {
615     shared_ptr<MethodHandler> h(
616                             new MethodHandler(
617                           auto_ptr<SocketCheckout>(new SocketCheckout(&_pool)),
618                                               method,
619                                               controller,
620                                               request,
621                                               response,
622                                               done
623                                              ));
624
625     MethodHandler::execute(h);
626 }
627
628 } // namespace bicker