1 #include "protobufrpc.h"
3 #include <google/protobuf/descriptor.h>
4 #include <google/protobuf/io/zero_copy_stream_impl.h>
5 #include <google/protobuf/io/coded_stream.h>
6 #include <boost/shared_ptr.hpp>
7 #include <boost/bind.hpp>
9 #include <boost/functional/hash.hpp>
13 using namespace boost;
// Framing helper: callers hand in the current write position plus a value and
// continue writing at the returned pointer (see writeResponse() and
// MethodHandler::execute()).
// The returned pointer is `data` advanced by sizeof(T) bytes.
// NOTE(review): T is presumably a template parameter declared just above, and
// the store of `val` into `data` presumably happens before the return.
20 static void* void_write(void* data, T val)
23 return (char*)data + sizeof(T);
// Bundles the state of one in-flight server-side call: the controller, the
// method descriptor, the request and response messages, and the connection
// the reply is written to.
26 class ProtoBufRpcServiceRequest
29 ProtoBufRpcServiceRequest(
31 const MethodDescriptor* method,
34 shared_ptr<ProtoBufRpcConnection> conn
44 ~ProtoBufRpcServiceRequest()
// Completion hook: writes the response stored in `req` to the connection
// stored in `req`. handle_read() passes this function together with a
// heap-allocated ProtoBufRpcServiceRequest when it calls the service method.
// NOTE(review): confirm who deletes `req` once the response has been queued.
49 static void run(ProtoBufRpcServiceRequest *req)
52 req->_conn->writeResponse(req->_response.get());
// The controller, messages and connection are held via shared_ptr; the method
// descriptor is a plain pointer to const.
57 shared_ptr<RpcController> _ctrl;
58 const MethodDescriptor *_method;
59 shared_ptr<Message> _request;
60 shared_ptr<Message> _response;
61 shared_ptr<ProtoBufRpcConnection> _conn;
// Constructs a connection on the given io_service. RegisteredService builds
// connections as ProtoBufRpcConnection(*_io_service, _service.get()), so the
// second argument is the raw pointer of the service this connection serves.
64 ProtoBufRpcConnection::ProtoBufRpcConnection(asio::io_service& io_service,
// Accessor returning a reference to a tcp::socket; RegisteredService passes
// the result to async_accept() as the socket to accept into.
// NOTE(review): presumably returns the _socket member - confirm.
73 tcp::socket& ProtoBufRpcConnection::socket()
// Begins servicing the connection by issuing the first asynchronous read.
78 void ProtoBufRpcConnection::start()
// Read into up to 4096 bytes of space prepared in _buffer.
80 _socket.async_read_some(_buffer.prepare(4096),
// The handler is bound to shared_from_this(), so the bound handler holds a
// shared_ptr to this connection.
82 boost::bind(&ProtoBufRpcConnection::handle_read, shared_from_this(),
83 asio::placeholders::error,
84 asio::placeholders::bytes_transferred)));
// Serializes `msg` into _buffer as a length-prefixed frame and starts an
// asynchronous write of it.
// Frame layout: the payload size as an int converted with htonl(), followed
// by the serialized message bytes.
// NOTE(review): this writes into the same _buffer member that start() and
// handle_read() use for incoming data - confirm the two uses cannot overlap.
87 void ProtoBufRpcConnection::writeResponse(Message *msg)
// rlen: payload size; len: the same size in network byte order; mlen: total
// frame size (prefix + payload).
89 int rlen = msg->ByteSize();
90 int len = htonl(rlen);
91 int mlen = sizeof(len) + rlen;
// Reserve space for the whole frame and get a raw pointer to it.
93 void * data = asio::buffer_cast<void*>(_buffer.prepare(mlen));
// Emit the length prefix; `data` now points at the payload position.
95 data = void_write(data, len);
97 using google::protobuf::io::ArrayOutputStream;
// Serialize directly into the reserved space.
99 ArrayOutputStream as(data, rlen);
// The boolean result of SerializeToZeroCopyStream is not checked here.
101 msg->SerializeToZeroCopyStream(&as);
// Make the frame part of the buffer's readable sequence.
103 _buffer.commit(mlen);
// Start the write; handle_write() is the completion handler.
105 asio::async_write(_socket,
108 boost::bind(&ProtoBufRpcConnection::handle_write,
110 asio::placeholders::error,
111 asio::placeholders::bytes_transferred)));
// Completion handler for reads on the server side of a connection.
// Drives a small state machine over _state:
//   STATE_NONE              - waiting for the header (_id followed by _len)
//   STATE_HAVE_ID_AND_LEN /
//   STATE_WAITING_FOR_DATA  - header known, waiting for _len payload bytes
// Once the payload is buffered, the request is parsed and dispatched to the
// service.
// Throws std::runtime_error("ParseFromCodedStream") if the request cannot be
// parsed.
115 void ProtoBufRpcConnection::handle_read(const error_code& e,
116 std::size_t bytes_transferred)
// Make the newly received bytes readable.
120 _buffer.commit(bytes_transferred);
122 if (_state == STATE_NONE)
// Proceed only once the complete header is buffered.
124 if (_buffer.size() >= sizeof(_id) + sizeof(_len))
// The header bytes are copied out of the stream buffer (into `b`) and then
// removed from it.
127 buffers_begin(_buffer.data()),
128 buffers_begin(_buffer.data())
129 + sizeof(_id) + sizeof(_len)
132 _buffer.consume(sizeof(_id) + sizeof(_len));
// _id and _len are read from the raw header bytes.
// NOTE(review): the client sends both fields through htonl() (see
// MethodHandler::execute); confirm the matching ntohl() is applied here.
134 _id = *((int*)b.c_str());
137 _len = *((unsigned int*)(b.c_str() + sizeof(_id)));
140 _state = STATE_HAVE_ID_AND_LEN;
148 if (_state == STATE_HAVE_ID_AND_LEN || _state == STATE_WAITING_FOR_DATA)
// Dispatch only when the entire payload has arrived.
150 if (_buffer.size() >= _len)
// _id is used as the method index within the service descriptor.
// NOTE(review): no range check on _id is visible before this lookup.
152 const MethodDescriptor* method =
153 _service->GetDescriptor()->method(_id);
// Fresh request/response instances created from the service's prototypes.
155 Message *req = _service->GetRequestPrototype(method).New();
156 Message *resp = _service->GetResponsePrototype(method).New();
158 using google::protobuf::io::ArrayInputStream;
159 using google::protobuf::io::CodedInputStream;
// Parse exactly _len bytes from the buffered data.
161 const void* data = asio::buffer_cast<const void*>(
164 ArrayInputStream as(data, _len);
165 CodedInputStream is(&as);
// Total bytes limit set to 512 MiB; -1 is passed as the second argument.
166 is.SetTotalBytesLimit(512 * 1024 * 1024, -1);
168 if (!req->ParseFromCodedStream(&is))
170 throw std::runtime_error("ParseFromCodedStream");
// Drop the parsed payload from the buffer.
173 _buffer.consume(_len);
// Invoke the service method; ProtoBufRpcServiceRequest::run is supplied along
// with a newly allocated ProtoBufRpcServiceRequest.
175 ProtoBufRpcController *ctrl = new ProtoBufRpcController();
176 _service->CallMethod(method,
181 &ProtoBufRpcServiceRequest::run,
182 new ProtoBufRpcServiceRequest(
// Payload not complete yet.
194 _state = STATE_WAITING_FOR_DATA;
// Shut the socket down in both directions, ignoring any error from shutdown.
// NOTE(review): presumably this is the path taken when `e` reports an error.
202 error_code ignored_ec;
203 _socket.shutdown(tcp::socket::shutdown_both, ignored_ec);
// Completion handler for asynchronous writes on the server connection.
207 void ProtoBufRpcConnection::handle_write(const error_code& e,
208 std::size_t bytes_transferred)
// Shut the socket down in both directions, ignoring any error from shutdown.
// NOTE(review): presumably the error path - confirm against the full source.
212 error_code ignored_ec;
213 _socket.shutdown(tcp::socket::shutdown_both, ignored_ec);
// Remove the bytes that were written from the buffer.
217 _buffer.consume(bytes_transferred);
// Issue another write with this same handler.
// NOTE(review): presumably only when data remains in _buffer - confirm.
221 asio::async_write(_socket,
224 boost::bind(&ProtoBufRpcConnection::handle_write,
226 asio::placeholders::error,
227 asio::placeholders::bytes_transferred)));
// Creates the server and allocates the asio::io_service held in _io_service.
236 ProtoBufRpcServer::ProtoBufRpcServer()
237 :_io_service(new asio::io_service())
// Registers `service` to be served on TCP port `port` by appending a new
// RegisteredService to _services.
// NOTE(review): the value returned to the caller is not visible here.
241 bool ProtoBufRpcServer::registerService(uint16_t port,
242 shared_ptr<Service> service)
244 // This is not thread safe
246 // The RegisteredService Constructor fires up the appropriate
247 // async accepts for the service
248 _services.push_back(shared_ptr<RegisteredService>(
249 new RegisteredService(
// Helper taking an io_service pointer; it installs an ITIMER_PROF interval
// timer via setitimer().
// NOTE(review): presumably it then runs the io_service - confirm.
257 void run_wrapper(asio::io_service *io_service)
// NOTE(review): no initialisation of `itimer` is visible between this
// declaration and the setitimer() call; confirm it is filled in first.
259 struct itimerval itimer;
260 setitimer(ITIMER_PROF, &itimer, NULL);
// Runs the server's io_service on a set of worker threads.
// Throws std::runtime_error if no service has been registered.
265 void ProtoBufRpcServer::run()
269 if (_services.size() == 0)
271 throw std::runtime_error("No services registered for ProtoBufRpcServer");
// One worker thread per processor currently online.
// NOTE(review): sysconf() returns -1 on failure; no check is visible here.
274 size_t nprocs = sysconf(_SC_NPROCESSORS_ONLN);
276 vector<shared_ptr<thread> > threads;
277 for (size_t i = 0; i < nprocs; ++i)
// Each thread is constructed with asio::io_service::run and the shared
// io_service pointer.
279 shared_ptr<thread> t(new thread(
282 &asio::io_service::run,
283 _io_service.get())));
284 threads.push_back(t);
// Second pass over the started threads.
// NOTE(review): presumably joins each thread - confirm.
287 for (size_t i = 0; i < threads.size(); ++i)
// Standard exceptions are reported on stderr; the tag and e.what() are
// streamed back to back with no separator.
292 catch (std::exception &e)
294 std::cerr << "ProtoBufRpcService" << e.what() << std::endl;
// Shutdown entry point for the server.
// NOTE(review): confirm what this stops (io_service, acceptors, worker
// threads) and whether it may be called while run() is executing.
298 void ProtoBufRpcServer::shutdown()
// Sets up the acceptor for one registered service and arms the first
// asynchronous accept.
303 ProtoBufRpcServer::RegisteredService::RegisteredService(
304 shared_ptr<asio::io_service> io_service,
306 shared_ptr<Service> service
308 :_io_service(io_service),
// IPv4 endpoint on the service's port.
311 _endpoint(tcp::v4(), _port),
312 _acceptor(*_io_service),
// The connection object for the next client is created ahead of time.
313 _new_connection(new ProtoBufRpcConnection(*_io_service, _service.get()))
// Open the acceptor, enable address reuse and bind it to the endpoint.
// NOTE(review): no listen() call is visible between bind() and async_accept();
// confirm the acceptor is put into the listening state.
315 _acceptor.open(_endpoint.protocol());
316 _acceptor.set_option(tcp::acceptor::reuse_address(true));
317 _acceptor.bind(_endpoint);
// Accept into the pre-created connection's socket; handle_accept() is the
// completion handler.
319 _acceptor.async_accept(_new_connection->socket(),
320 boost::bind(&ProtoBufRpcServer::RegisteredService::handle_accept,
322 asio::placeholders::error));
// Completion handler for async_accept on this service's acceptor.
// NOTE(review): presumably the steps below run only when `e` reports success.
325 void ProtoBufRpcServer::RegisteredService::handle_accept(const error_code& e)
// Start servicing the connection that was just accepted.
329 _new_connection->start();
// Replace it with a fresh connection object and arm the next accept on that
// object's socket.
330 _new_connection.reset(new ProtoBufRpcConnection(*_io_service, _service.get()));
331 _acceptor.async_accept(_new_connection->socket(),
332 boost::bind(&ProtoBufRpcServer::RegisteredService::handle_accept,
334 asio::placeholders::error));
// Member definitions of ProtoBufRpcController. Instances are created with
// `new ProtoBufRpcController()` in ProtoBufRpcConnection::handle_read().
// NOTE(review): presumably an implementation of the protobuf RpcController
// interface - confirm against the class declaration.
339 ProtoBufRpcController::ProtoBufRpcController()
343 ProtoBufRpcController::~ProtoBufRpcController()
347 void ProtoBufRpcController::Reset()
351 bool ProtoBufRpcController::Failed() const
356 string ProtoBufRpcController::ErrorText() const
361 void ProtoBufRpcController::StartCancel()
// The `reason` parameter is unnamed, so the text passed in is not used by the
// body.
365 void ProtoBufRpcController::SetFailed(const string &/*reason*/)
369 bool ProtoBufRpcController::IsCanceled() const
// The `callback` parameter is unnamed, so the closure passed in is not used by
// the body.
374 void ProtoBufRpcController::NotifyOnCancel(Closure * /*callback*/)
// Client-side state for a single RPC call: sends the request frame over a
// checked-out socket, then reads the length-prefixed response and parses it.
// Every stage is a static function taking a shared_ptr to the handler.
378 class ProtoBufRpcChannel::MethodHandler
379 : public enable_shared_from_this<MethodHandler>,
380 private boost::noncopyable
// Takes ownership of the socket checkout; method, controller and request are
// received as raw pointers.
383 MethodHandler(auto_ptr<SocketCheckout> socket,
384 const MethodDescriptor * method,
385 RpcController * controller,
386 const Message * request,
392 _controller(controller),
// Builds the request frame in _buffer and starts sending it.
// Frame layout: method index (htonl), payload size (htonl), serialized request.
405 static void execute(shared_ptr<MethodHandler> this_ptr)
407 int index = htonl(this_ptr->_method->index());
408 int rlen = this_ptr->_request->ByteSize();
409 int len = htonl(rlen);
// mlen: total frame size (index + length prefix + payload).
411 int mlen = sizeof(index) + sizeof(len) + rlen;
413 void * data = asio::buffer_cast<void*>(this_ptr->_buffer.prepare(mlen));
// Emit the two header fields; `data` then points at the payload position.
415 data = void_write(data, index);
416 data = void_write(data, len);
418 using google::protobuf::io::ArrayOutputStream;
420 ArrayOutputStream as(data, rlen);
// The boolean result of SerializeToZeroCopyStream is not checked here.
422 this_ptr->_request->SerializeToZeroCopyStream(&as);
423 this_ptr->_buffer.commit(mlen);
// Send whatever is readable in _buffer; handle_write() deals with any
// remainder.
425 (*(this_ptr->_socket))->async_send(this_ptr->_buffer.data(),
426 boost::bind(&ProtoBufRpcChannel::MethodHandler::handle_write,
428 asio::placeholders::error,
429 asio::placeholders::bytes_transferred));
// Send completion handler: drops the bytes that went out, keeps sending while
// data remains, then starts reading the response length into _len.
432 static void handle_write(shared_ptr<MethodHandler> this_ptr,
434 std::size_t bytes_transferred)
438 this_ptr->_buffer.consume(bytes_transferred);
// Data still pending: send the rest.
440 if (this_ptr->_buffer.size())
442 (*(this_ptr->_socket))->async_send(this_ptr->_buffer.data(),
443 boost::bind(&ProtoBufRpcChannel::MethodHandler::handle_write,
445 asio::placeholders::error,
446 asio::placeholders::bytes_transferred));
// Receive the response length prefix directly into the _len member.
450 (*(this_ptr->_socket))->async_receive(
451 buffer(&this_ptr->_len, sizeof(this_ptr->_len)),
453 &ProtoBufRpcChannel::MethodHandler::handle_read_len,
455 asio::placeholders::error,
456 asio::placeholders::bytes_transferred)
// Failure path: record the error text on the controller and close the socket.
461 this_ptr->_controller->SetFailed(e.message());
462 (*(this_ptr->_socket))->close();
// Completion handler for the length-prefix read.
466 static void handle_read_len(shared_ptr<MethodHandler> this_ptr,
468 std::size_t bytes_transferred)
// Continue only when the read succeeded and delivered the whole prefix.
470 if (!e && bytes_transferred == sizeof(this_ptr->_len))
// Convert the length to host byte order, then read the payload into _buffer.
// NOTE(review): no upper bound on the peer-supplied _len is visible before
// prepare() is called with it.
472 this_ptr->_len = ntohl(this_ptr->_len);
473 (*(this_ptr->_socket))->async_receive(
474 this_ptr->_buffer.prepare(this_ptr->_len),
476 &ProtoBufRpcChannel::MethodHandler::handle_read_response,
478 asio::placeholders::error,
479 asio::placeholders::bytes_transferred
// Failure path: record the error text on the controller and close the socket.
485 this_ptr->_controller->SetFailed(e.message());
486 (*(this_ptr->_socket))->close();
// Completion handler for payload reads; parses the response once _len bytes
// are buffered, otherwise requests the remaining bytes.
// Throws std::runtime_error("ParseFromCodedStream") if parsing fails.
490 static void handle_read_response(shared_ptr<MethodHandler> this_ptr,
492 std::size_t bytes_transferred)
496 this_ptr->_buffer.commit(bytes_transferred);
497 if (this_ptr->_buffer.size() >= this_ptr->_len)
499 using google::protobuf::io::ArrayInputStream;
500 using google::protobuf::io::CodedInputStream;
502 const void* data = asio::buffer_cast<const void*>(
503 this_ptr->_buffer.data()
505 ArrayInputStream as(data, this_ptr->_len);
506 CodedInputStream is(&as);
// Total bytes limit set to 512 MiB; -1 is passed as the second argument.
507 is.SetTotalBytesLimit(512 * 1024 * 1024, -1);
509 if (!this_ptr->_response->ParseFromCodedStream(&is))
511 throw std::runtime_error("ParseFromCodedStream");
// Drop the parsed payload from the buffer.
514 this_ptr->_buffer.consume(this_ptr->_len);
// Not enough data yet: ask for exactly the bytes still missing.
518 (*(this_ptr->_socket))->async_receive(
519 this_ptr->_buffer.prepare(this_ptr->_len - this_ptr->_buffer.size()),
521 &ProtoBufRpcChannel::MethodHandler::handle_read_response,
523 asio::placeholders::error,
524 asio::placeholders::bytes_transferred
// Failure path: record the error text on the controller and close the socket.
532 this_ptr->_controller->SetFailed(e.message());
533 (*(this_ptr->_socket))->close();
// _socket is owned through auto_ptr; _method, _controller and _request are raw
// pointers.
// NOTE(review): confirm the pointed-to objects outlive the asynchronous
// operations started above.
538 auto_ptr<SocketCheckout> _socket;
539 const MethodDescriptor * _method;
540 RpcController * _controller;
541 const Message * _request;
// Holds the outgoing request frame first and the incoming response payload
// afterwards.
544 asio::streambuf _buffer;
// Creates a client channel to `remotehost`: resolves the address, configures
// the socket pool with it, arms an accept on a local acceptor and starts a
// thread for the io_service.
// Throws syserr::system_error(host_not_found) when resolution yields no
// endpoint.
551 ProtoBufRpcChannel::ProtoBufRpcChannel(const string &remotehost,
553 :_remote_host(remotehost), _port(port),
554 _resolver(_io_service),
555 _acceptor(_io_service),
// NOTE(review): the meaning of 2000 (presumably a pool size limit) is defined
// by the pool type, which is not visible here.
556 _pool(2000, _io_service),
557 _lame_socket(_io_service),
559 // &asio::io_service::run,
// Resolve the remote host and port with a blocking resolve() call.
563 tcp::resolver::query query(_remote_host, _port);
564 tcp::resolver::iterator endpoint_iterator = _resolver.resolve(query);
565 tcp::resolver::iterator end;
567 error_code error = asio::error::host_not_found;
569 if (endpoint_iterator == end) throw syserr::system_error(error);
// Only the first resolved endpoint is handed to the pool.
571 _pool.setEndpoint(*endpoint_iterator);
// Local IPv4 acceptor with port 0 and an accept armed on _lame_socket.
// NOTE(review): presumably this exists only to keep an operation outstanding
// so io_service::run() does not return - confirm.
573 tcp::endpoint e(tcp::v4(), 0);
574 _acceptor.open(e.protocol());
575 _acceptor.set_option(tcp::acceptor::reuse_address(true));
// `this` is bound as a raw pointer in the accept handler.
578 _acceptor.async_accept(_lame_socket,
579 boost::bind(&ProtoBufRpcChannel::lame_handle_accept, this,
580 asio::placeholders::error));
// Background thread constructed with asio::io_service::run.
582 _thread = shared_ptr<thread>(new thread(
584 &asio::io_service::run,
// Completion handler for the accept armed in the constructor; it arms another
// accept on _lame_socket with itself as the handler.
// NOTE(review): confirm how `err` is treated before re-arming.
588 void ProtoBufRpcChannel::lame_handle_accept(const error_code &err)
592 _acceptor.async_accept(_lame_socket,
593 boost::bind(&ProtoBufRpcChannel::lame_handle_accept,
595 asio::placeholders::error));
// Destructor: calls cancelAndClear() on the socket pool.
// NOTE(review): confirm the io_service thread started by the constructor is
// also stopped and joined.
599 ProtoBufRpcChannel::~ProtoBufRpcChannel()
601 _pool.cancelAndClear();
// Starts one RPC: builds a MethodHandler holding a socket checked out from
// _pool and kicks it off with MethodHandler::execute().
// `method`, `controller` and `request` are received as raw pointers.
608 void ProtoBufRpcChannel::CallMethod(
609 const MethodDescriptor * method,
610 RpcController * controller,
611 const Message * request,
// The handler is held through a shared_ptr; the socket checkout is passed in
// as an auto_ptr.
615 shared_ptr<MethodHandler> h(
617 auto_ptr<SocketCheckout>(new SocketCheckout(&_pool)),
625 MethodHandler::execute(h);
628 } // namespace bicker