diff --git a/examples/rpc.lua b/examples/rpc.lua index 3fc7f17..0c71957 100644 --- a/examples/rpc.lua +++ b/examples/rpc.lua @@ -1,6 +1,6 @@ local core = require "core" local crypto = require "core.crypto" -local rpc = require "cluster.rpc" +local cluster = require "core.cluster" local zproto = require "zproto" local proto = zproto:parse [[ @@ -12,10 +12,33 @@ pong 0x2 { } ]] +assert(proto) +local function unmarshal(cmd, buf, size) + local dat, size = proto:unpack(buf, size, true) + local body = proto:decode(cmd, dat, size) + return body +end -local server = rpc.listen { - addr = "127.0.0.1:9999", - proto = proto, +local function marshal(cmd, body) + if type(cmd) == "string" then + cmd = proto:tag(cmd) + end + local dat, size = proto:encode(cmd, body, true) + local buf, size = proto:pack(dat, size, true) + return cmd, buf, size +end + +local function call(msg, fd) + print("callee", msg.txt, fd) + return "pong", msg +end + +local router = setmetatable({}, {__index = function(t, k) return call end}) + +local server = cluster.new { + marshal = marshal, + unmarshal = unmarshal, + router = router, accept = function(fd, addr) print("accept", fd, addr) end, @@ -23,28 +46,28 @@ local server = rpc.listen { close = function(fd, errno) print("close", fd, errno) end, - - call = function(msg, cmd, fd) - print("callee", msg.txt, cmd, fd) - return "pong", msg - end } +server.listen("127.0.0.1:9999") + +local client = cluster.new { + marshal = marshal, + unmarshal = unmarshal, + router = router, + close = function(fd, errno) + print("close", fd, errno) + end, +} core.start(function() for i = 1, 3 do core.fork(function() - local conn = rpc.connect { - addr = "127.0.0.1:9999", - proto = proto, - timeout = 5000, - close = function(fd, errno) - end, - } - while true do + local fd, err = client.connect("127.0.0.1:9999") + print("connect", fd, err) + for j = 1, 10000 do local txt = crypto.randomkey(5) - local ack, cmd = conn:call("ping", {txt = txt}) - print("caller", conn, txt, ack.txt) + local ack, cmd = client.ping(fd, {txt = txt}) + print("caller", fd, txt, ack.txt) assert(ack.txt == txt) assert(cmd == proto:tag("pong")) core.sleep(1000) diff --git a/lualib-src/lualib-core.c b/lualib-src/lualib-core.c index 341ff99..e594189 100644 --- a/lualib-src/lualib-core.c +++ b/lualib-src/lualib-core.c @@ -459,8 +459,8 @@ lsendsize(lua_State *L) static int ltracespan(lua_State *L) { - silly_trace_span_t span; - span = (silly_trace_span_t)luaL_checkinteger(L, 1); + silly_tracespan_t span; + span = (silly_tracespan_t)luaL_checkinteger(L, 1); silly_trace_span(span); return 0; } @@ -468,7 +468,7 @@ ltracespan(lua_State *L) static int ltracenew(lua_State *L) { - silly_trace_id_t traceid; + silly_traceid_t traceid; traceid = silly_trace_new(); lua_pushinteger(L, (lua_Integer)traceid); return 1; @@ -477,13 +477,13 @@ ltracenew(lua_State *L) static int ltraceset(lua_State *L) { - silly_trace_id_t traceid; + silly_traceid_t traceid; lua_State *co = lua_tothread(L, 1); silly_worker_resume(co); if lua_isnoneornil(L, 2) { traceid = TRACE_WORKER_ID; } else { - traceid = (silly_trace_id_t)luaL_checkinteger(L, 2); + traceid = (silly_traceid_t)luaL_checkinteger(L, 2); } traceid = silly_trace_set(traceid); lua_pushinteger(L, (lua_Integer)traceid); @@ -493,7 +493,7 @@ ltraceset(lua_State *L) static int ltraceget(lua_State *L) { - silly_trace_id_t traceid; + silly_traceid_t traceid; traceid = silly_trace_get(); lua_pushinteger(L, (lua_Integer)traceid); return 1; diff --git a/lualib-src/lualib-netpacket.c b/lualib-src/lualib-netpacket.c index 5aa4430..d7f5be6 100644 --- a/lualib-src/lualib-netpacket.c +++ b/lualib-src/lualib-netpacket.c @@ -11,6 +11,7 @@ #include "silly_trace.h" #include "silly_malloc.h" +#define ACK_BIT (1UL << 31) #define DEFAULT_QUEUE_SIZE 2048 #define HASH_SIZE 2048 #define HASH(a) (a % HASH_SIZE) @@ -43,6 +44,8 @@ struct netpacket { struct incomplete *hash[HASH_SIZE]; }; +static session_t session_idx = 0; + static int lcreate(lua_State *L) { @@ -225,14 +228,11 @@ clear_incomplete(lua_State *L, int sid) } static inline const char * -getbuffer(lua_State *L, int *stk, size_t *sz) +getbuffer(lua_State *L, int n, size_t *sz) { - int n = *stk; if (lua_type(L, n) == LUA_TSTRING) { - *stk = n + 1; return lua_tolstring(L, n, sz); } else { - *stk = n + 2; *sz = luaL_checkinteger(L, n + 1); return lua_touserdata(L, n); } @@ -255,143 +255,112 @@ pop_packet(lua_State *L) } } -static int -lpop(lua_State *L) -{ - struct packet *pk = pop_packet(L); - if (pk == NULL) - return 0; - lua_pushinteger(L, pk->fd); - lua_pushlightuserdata(L, pk->buff); - lua_pushinteger(L, pk->size); - return 3; -} +//rpc_cookie {traceid(uint64),cmd(uint32),session(uint32)} -static int -lpack(lua_State *L) -{ - uint8_t *p; - int stk = 1; - size_t size; - const char *str; - str = getbuffer(L, &stk, &size); - if (size > USHRT_MAX) - luaL_error(L, "netpacket.pack data large then:%d\n", USHRT_MAX); - p = silly_malloc(2 + size); - p[0] = (size >> 8) & 0xff; - p[1] = size & 0xff; - memcpy(p + 2, str, size); - lua_pushlightuserdata(L, p); - lua_pushinteger(L, 2 + size); - return 2; -} +#define req_cookie_size (sizeof(silly_traceid_t)+sizeof(cmd_t)+sizeof(session_t)) +#define req_traceid_ref(ptr) (*(silly_traceid_t*)(ptr)) +#define req_cmd_ref(ptr) (*(cmd_t *)(ptr+sizeof(silly_traceid_t))) +#define req_session_ref(ptr) (*(session_t*)(ptr+sizeof(silly_traceid_t)+sizeof(cmd_t))) + +#define ack_cookie_size (sizeof(session_t)) +#define ack_session_ref(ptr) (*(session_t*)(ptr)) static int -lmsgpop(lua_State *L) +lpop(lua_State *L) { int size; char *buf; - cmd_t cmd; + session_t session; struct packet *pk = pop_packet(L); if (pk == NULL) return 0; - size = pk->size - sizeof(cmd_t); + size = pk->size - ack_cookie_size; buf = pk->buff; if (size < 0) return 0; //WARN: pointer cast may not align, can't cross platform - cmd = *(cmd_t *)(buf + size); - lua_pushinteger(L, pk->fd); - lua_pushlightuserdata(L, buf); - lua_pushinteger(L, size); - lua_pushinteger(L, cmd); - return 4; + session = ack_session_ref(buf+size); + if ((session & ACK_BIT) == ACK_BIT) { //rpc ack + lua_pushinteger(L, pk->fd); + lua_pushlightuserdata(L, buf); + lua_pushinteger(L, size); + lua_pushinteger(L, (lua_Integer)(session & ~ACK_BIT)); + lua_pushnil(L); //cmd + lua_pushinteger(L, 0); //traceid + } else { + void *cookie; + size = pk->size - req_cookie_size; + cookie = (void *)(buf + size); + lua_pushinteger(L, pk->fd); + lua_pushlightuserdata(L, buf); + lua_pushinteger(L, size); + lua_pushinteger(L, session); + lua_pushinteger(L, req_cmd_ref(cookie)); + lua_pushinteger(L, (lua_Integer)req_traceid_ref(cookie)); + } + return 6; } static int -lmsgpack(lua_State *L) +lreq(lua_State *L) { + cmd_t cmd; uint8_t *p; const char *str; + void *cookie; size_t size, body; - int cmd, stk = 1; - str = getbuffer(L, &stk, &size); - if (size > (USHRT_MAX - sizeof(cmd_t))) { + session_t session; + silly_traceid_t traceid; + cmd = luaL_checkinteger(L, 1); + traceid = luaL_checkinteger(L, 2); + str = getbuffer(L, 3, &size); + if (size > (USHRT_MAX - req_cookie_size)) { luaL_error(L, "netpacket.pack data large then:%d\n", - USHRT_MAX - sizeof(cmd_t)); + USHRT_MAX - req_cookie_size); + } + session = session_idx++; + if (session >= ACK_BIT) { + session_idx = 0; + session = 0; } - cmd = luaL_checkinteger(L, stk); - body = size + sizeof(cmd_t); + body = size + req_cookie_size; p = silly_malloc(2 + body); p[0] = (body >> 8) & 0xff; p[1] = body & 0xff; memcpy(p + 2, str, size); //WARN: pointer cast may not align, can't cross platform - *(cmd_t *)&p[size+2] = cmd; + cookie = (void *)&p[2 + size]; + req_cmd_ref(cookie) = cmd; + req_session_ref(cookie) = session; + req_traceid_ref(cookie) = traceid; + lua_pushinteger(L, session); lua_pushlightuserdata(L, p); lua_pushinteger(L, 2 + body); - return 2; -} - -struct rpc_cookie { - cmd_t cmd; - session_t session; - silly_trace_id_t traceid; -}; - -static int -lrpcpop(lua_State *L) -{ - int size; - char *buf; - struct rpc_cookie *rpc; - struct packet *pk = pop_packet(L); - if (pk == NULL) - return 0; - size = pk->size - sizeof(struct rpc_cookie); - buf = pk->buff; - if (size < 0) - return 0; - //WARN: pointer cast may not align, can't cross platform - rpc = (struct rpc_cookie *)(buf + size); - lua_pushinteger(L, pk->fd); - lua_pushlightuserdata(L, buf); - lua_pushinteger(L, size); - lua_pushinteger(L, rpc->cmd); - lua_pushinteger(L, rpc->session); - lua_pushinteger(L, (lua_Integer)rpc->traceid); - return 6; + return 3; } static int -lrpcpack(lua_State *L) +lack(lua_State *L) { - cmd_t cmd; uint8_t *p; const char *str; + void *cookie; size_t size, body; - struct rpc_cookie *rpc; - int stk = 1; session_t session; - silly_trace_id_t traceid; - str = getbuffer(L, &stk, &size); - if (size > (USHRT_MAX - sizeof(struct rpc_cookie))) { + session = luaL_checkinteger(L, 1) | ACK_BIT; + str = getbuffer(L, 2, &size); + if (size > (USHRT_MAX - ack_cookie_size)) { luaL_error(L, "netpacket.pack data large then:%d\n", - USHRT_MAX - sizeof(struct rpc_cookie)); + USHRT_MAX - ack_cookie_size); } - cmd = luaL_checkinteger(L, stk); - session = luaL_checkinteger(L, stk+1); - traceid = luaL_checkinteger(L, stk+2); - body = size + sizeof(struct rpc_cookie); + body = size + ack_cookie_size; p = silly_malloc(2 + body); p[0] = (body >> 8) & 0xff; p[1] = body & 0xff; memcpy(p + 2, str, size); //WARN: pointer cast may not align, can't cross platform - rpc = (struct rpc_cookie *)&p[2 + size]; - rpc->cmd = cmd; - rpc->session = session; - rpc->traceid = traceid; + cookie = (void *)&p[2 + size]; + ack_session_ref(cookie) = session; lua_pushlightuserdata(L, p); lua_pushinteger(L, 2 + body); return 2; @@ -480,14 +449,11 @@ int luaopen_core_netpacket(lua_State *L) luaL_Reg tbl[] = { {"create", lcreate}, {"pop", lpop}, - {"pack", lpack}, - {"msgpop", lmsgpop}, - {"msgpack", lmsgpack}, - {"rpcpop", lrpcpop}, - {"rpcpack", lrpcpack}, + {"req", lreq}, + {"ack", lack}, {"clear", lclear}, - {"tostring", ltostring}, {"drop", ldrop}, + {"tostring", ltostring}, {"message", lmessage}, {NULL, NULL}, }; diff --git a/lualib/core/cluster.lua b/lualib/core/cluster.lua new file mode 100644 index 0000000..c926a38 --- /dev/null +++ b/lualib/core/cluster.lua @@ -0,0 +1,300 @@ +local core = require "core" +local mutex = require "core.sync.mutex" +local dns = require "core.dns" +local logger = require "core.logger" +local np = require "core.netpacket" +local type = type +local pairs = pairs +local assert = assert +local tcp_connect = core.tcp_connect +local tcp_send = core.tcp_send +local tcp_close = core.socket_close +local tcp_listen = core.tcp_listen +local pcall = core.pcall +local timeout = core.timeout +local timercancel = core.timercancel +local setmetatable = setmetatable +local NIL = {} + +local mt = { + __gc = function(self) + local fdaddr = self.__fdaddr + for k, _ in pairs(fdaddr) do + if type(k) == "number" then + tcp_close(k) + end + fdaddr[k] = nil + end + end, + __index = function(self, msg) + local waitfor = self.__waitfor + local fdaddr = self.__fdaddr + local marshal = self.__marshal + local callret = self.__callret + local fn = function(fd, body) + if not fdaddr[fd] then + return nil, "closed" + end + local traceid = core.tracepropagate() + local cmd, dat, sz = marshal(msg, body) + local session, body, size = np.req(cmd, traceid, dat, sz) + local ok, err = tcp_send(fd, body, size) + if not ok then + return nil, err + end + local ret = callret[msg] + if ret then + return waitfor(session, ret) + end + end + self[msg] = fn + return fn + end +} + +local function connect_wrapper(self) + local lock = self.__lock + local fdaddr = self.__fdaddr + local connecting = self.__connecting + return function(addr) + local fd = fdaddr[addr] + if fd then + return fd, "connected" + end + connecting[addr] = true + local l = lock:lock(addr) + local fd = fdaddr[addr] + if fd then + return fd, "connected" + end + while true do + local newaddr = addr + local name, port = addr:match("([^:]+):(%d+)") + if dns.isname(name) then + local ip = dns.lookup(name) + if ip then + newaddr = ip .. ":" .. port + else + newaddr = nil + logger.error("[rpc.client] dns lookup fail", name) + end + end + if newaddr then + local fd, errno = tcp_connect(newaddr, self.__event) + if fd then + if connecting[addr] then + connecting[addr] = nil + fdaddr[addr] = fd + fdaddr[fd] = addr + return fd, "ok" + else --already close + tcp_close(fd) + return nil, "active closed" + end + else + logger.error("[rpc.client] connect fail", addr, errno) + end + end + core.sleep(1000) + logger.info("[rpc.client] retry connect:", addr) + end + end +end + +local function listen_wrapper(self) + return function(addr, backlog) + local fd, errno = tcp_listen(addr, self.__event, backlog) + if not fd then + return fd, errno + end + self.__fdaddr[addr] = fd + self.__fdaddr[fd] = addr + return fd, nil + end +end + +local function close_wrapper(self) + return function(addr) + local connecting = self.__connecting + if connecting[addr] then + connecting[addr] = nil + return true, "connecting" + end + local fdaddr = self.__fdaddr + local fd = fdaddr[addr] + if not fd then + return false, "closed" + end + if type(addr) == "string" then + addr, fd = fd, addr + end + fdaddr[fd] = nil + fdaddr[addr] = nil + core.socket_close(fd) + return true, "connected" + end +end + +local function nop() end + +local function init_event(self, conf) + local waitpool = self.__waitpool + local ackcmd = self.__ackcmd + local fdaddr = self.__fdaddr + local callret = self.__callret + local ctx = self.__ctx + local router = assert(conf.router, "router") + local close = assert(conf.close, "close") + local marshal = assert(conf.marshal, "marshal") + local unmarshal = assert(conf.unmarshal, "unmarshal") + local accept = conf.accept or nop + local EVENT = {} + function EVENT.accept(fd, _, addr) + fdaddr[fd] = addr + fdaddr[addr] = fd + local ok, err = pcall(accept, fd, addr) + if not ok then + logger.error("[rpc.server] EVENT.accept", err) + np.clear(ctx, fd) + core.socket_close(fd) + end + end + + function EVENT.close(fd, errno) + local addr = fdaddr[fd] + fdaddr[fd] = nil + fdaddr[addr] = nil + local ok, err = pcall(close, fd, errno) + if not ok then + logger.error("[rpc.server] EVENT.close", err) + end + np.clear(ctx, fd) + end + + function EVENT.data() + local fd, buf, size, session, cmd, traceid = np.pop(ctx) + if not fd then + return + end + local otrace = core.trace(traceid) + core.fork(EVENT.data) + while true do + local dat + --pars + if cmd then + local fn = router[cmd] + if not fn then + np.drop(buf) + logger.error("[rpc.server] no router", cmd) + break + end + local body = unmarshal(cmd, buf, size) + np.drop(buf) + if not body then + logger.error("[rpc.server] decode fail", + session, cmd) + break + end + local ok, res = pcall(fn, body, fd) + if not ok then + logger.error("[rpc.server] call error", res) + break + end + local ackname = callret[cmd] + if not ackname then + break + end + --ack + res = res or NIL + local _, buf, sz = marshal(ackname, res) + tcp_send(fd, np.ack(session, buf, sz)) + else + local cmd = ackcmd[session] + if not cmd then --timeout + np.drop(buf) + logger.warn("[rpc.client] late session", + session) + break + end + local body = unmarshal(cmd, buf, size) + np.drop(buf) + if not body then + logger.error("[rpc.server] decode fail", + session) + break + end + local co = waitpool[session] + waitpool[session] = nil + core.wakeup(co, body) + end + --next + fd, buf, size, session, cmd, traceid = np.pop(ctx) + if not fd then + return + end + core.trace(traceid) + end + core.trace(otrace) + end + return function(type, fd, message, ...) + np.message(ctx, message) + assert(EVENT[type])(fd, ...) + end +end + + +local function waitfor_wrapper(self, expire) + expire = expire or 5000 + local waitpool = self.__waitpool + local ackcmd = self.__ackcmd + local timer_func = function(session) + local co = waitpool[session] + if not co then + logger.error("[rpc.client] timer error session:", session) + return + end + waitpool[session] = nil + ackcmd[session] = "timeout" + core.wakeup(co) + end + return function(session, ack) + local co = core.running() + local timer_id = timeout(expire, timer_func, session) + waitpool[session] = co + ackcmd[session] = ack + local body = core.wait() + if body then + timercancel(timer_id) + end + local cmd = ackcmd[session] + ackcmd[session] = nil + return body, cmd + end +end + +local M = {} +function M.new(conf) + local obj = { + __lock = mutex:new(), + __ctx = np.create(), + __fdaddr = {}, + __waitpool = {}, + __ackcmd = {}, + __waitfor = nil, + __event = nil, + __connecting = {}, + __callret = assert(conf.callret, "callret"), + __marshal = assert(conf.marshal, "marshal"), + connect = nil, + listen = nil, + close = nil, + } + obj.connect = connect_wrapper(obj) + obj.listen = listen_wrapper(obj) + obj.close = close_wrapper(obj) + obj.__event = init_event(obj, conf) + obj.__waitfor = waitfor_wrapper(obj, conf.timeout) + return setmetatable(obj, mt) +end + +return M diff --git a/lualib/core/cluster/msg.lua b/lualib/core/cluster/msg.lua deleted file mode 100644 index 1ce067d..0000000 --- a/lualib/core/cluster/msg.lua +++ /dev/null @@ -1,187 +0,0 @@ -local core = require "core" -local logger = require "core.logger" -local np = require "core.netpacket" -local pairs = pairs -local assert = assert -local type = type -local msg = {} -local msgserver = {} -local msgclient = {} -local queue = np.create() - -local function gc(obj) - if not obj.fd then - return - end - if obj.fd < 0 then - return - end - core.socket_close(obj.fd) - obj.fd = false -end - -local servermt = {__index = msgserver, __gc = gc} -local clientmt = {__index = msgclient, __gc = gc} - -local function event_callback(proto, accept_cb, close_cb, data_cb) - local EVENT = {} - function EVENT.accept(fd, portid, addr) - local ok, err = core.pcall(accept_cb, fd, addr) - if not ok then - logger.error("[msg] EVENT.accept", err) - core.socket_close(fd) - end - end - function EVENT.close(fd, errno) - local ok, err = core.pcall(close_cb, fd, errno) - if not ok then - logger.error("[msg] EVENT.close", err) - end - end - function EVENT.data() - local f, d, sz, cmd = np.msgpop(queue) - if not f then - return - end - core.fork(EVENT.data) - while true do - --parse - local dat, size = proto:unpack(d, sz, true) - np.drop(d) - local obj = proto:decode(cmd, dat, size) - local ok, err = core.pcall(data_cb, f, cmd, obj) - if not ok then - logger.error("[msg] dispatch socket", err) - end - f, d, sz, cmd = np.msgpop(queue) - if not f then - return - end - end - end - return function (type, fd, message, ...) - np.message(queue, message) - assert(EVENT[type])(fd, ...) - end -end - ----server -local function sendmsg(self, fd, cmd, data) - local proto = self.proto - if type(cmd) == "string" then - cmd = proto:tag(cmd) - end - local dat, sz = proto:encode(cmd, data, true) - dat, sz= proto:pack(dat, sz, true) - return core.tcp_send(fd, np.msgpack(dat, sz, cmd)) -end -msgserver.send = sendmsg -msgserver.sendbin = function(self, fd, cmd, bin) - return core.tcp_send(fd, np.msgpack(bin, cmd)) -end -msgserver.multipack = function(self, cmd, dat, n) - local proto = self.proto - if type(cmd) == "string" then - cmd = proto:tag(cmd) - end - local dat, sz = proto:encode(cmd, dat, true) - dat, sz = proto:pack(dat, sz, true) - dat, sz = np.msgpack(dat, sz, cmd) - dat, sz = core.multipack(dat, sz, n) - return dat, sz -end - -msgserver.multicast = function(self, fd, data, sz) - return core.tcp_multicast(fd, data, sz) -end - -function msgserver.stop(self) - gc(self) -end - -function msgserver.close(self, fd) - core.socket_close(fd) -end - ------client -local function wakeupall(self) - local q = self.connectqueue - for k, v in pairs(q) do - core.wakeup(v) - q[k] = nil - end -end - -local function checkconnect(self) - if self.fd and self.fd >= 0 then - return self.fd - end - if not self.fd then --disconnected - self.fd = -1 - local fd = core.tcp_connect(self.addr, self.callback) - if not fd then - self.fd = false - else - self.fd = fd - end - wakeupall(self) - return self.fd - else - local co = core.running() - local t = self.connectqueue - t[#t + 1] = co - core.wait() - return self.fd and self.fd > 0 - end -end - -function msgclient.close(self) - gc(self) -end - -function msgclient.send(self, cmd, data) - local fd = checkconnect(self) - if not fd then - return false - end - return sendmsg(self, fd, cmd, data) -end - -function msg.connect(conf) - local obj = { - fd = false, - callback = false, - addr = conf.addr, - proto = conf.proto, - connectqueue = {}, - } - local close_cb = assert(conf.close, "clientcb close") - local data_cb = assert(conf.data, "clientcb data") - obj.callback = event_callback(conf.proto, nil, close_cb, data_cb) - setmetatable(obj, clientmt) - checkconnect(obj) - return obj -end - -function msg.listen(conf) - local obj = { - fd = false, - callback = false, - addr = conf.addr, - proto = conf.proto, - } - local accept_cb = assert(conf.accept, "servercb accept") - local close_cb = assert(conf.close, "servercb close") - local data_cb = assert(conf.data, "servercb data") - obj.callback = event_callback(conf.proto, accept_cb, close_cb, data_cb) - setmetatable(obj, servermt) - local fd, errno = core.tcp_listen(obj.addr, obj.callback, obj.backlog) - if not fd then - return nil, errno - end - return obj -end - - -return msg - diff --git a/lualib/core/cluster/rpc.lua b/lualib/core/cluster/rpc.lua deleted file mode 100644 index afa36be..0000000 --- a/lualib/core/cluster/rpc.lua +++ /dev/null @@ -1,336 +0,0 @@ -local core = require "core" -local logger = require "core.logger" -local np = require "core.netpacket" -local zproto = require "zproto" -local type = type -local pairs = pairs -local assert = assert -local pack = string.pack -local unpack = string.unpack -local tcp_send = core.tcp_send -local queue = np.create() -local NIL = {} -local rpc = {} - -local function gc(obj) - if not obj.fd then - return - end - local fd = obj.fd - obj.fd = false - if fd < 0 then - return - end - core.socket_close(fd) -end - ------------server -local server = {} -local servermt = {__index = server} - -local function server_listen(self) - local EVENT = {} - local accept = assert(self.accept, "accept") - local close = assert(self.close, "close") - local call = assert(self.call, "call") - local proto = self.proto - function EVENT.accept(fd, portid, addr) - local ok, err = core.pcall(accept, fd, addr) - if not ok then - logger.error("[rpc.server] EVENT.accept", err) - np.clear(queue, fd) - core.socket_close(fd) - end - end - - function EVENT.close(fd, errno) - local ok, err = core.pcall(close, fd, errno) - if not ok then - logger.error("[rpc.server] EVENT.close", err) - end - np.clear(queue, fd) - end - - function EVENT.data() - local fd, buf, size, cmd, session, traceid = np.rpcpop(queue) - if not fd then - return - end - local otrace = core.trace(traceid) - core.fork(EVENT.data) - while true do - local dat - --parse - dat, size = proto:unpack(buf, size, true) - np.drop(buf) - local body = proto:decode(cmd, dat, size) - if not body then - logger.error("[rpc.server] decode fail", - session, cmd) - return - end - local ok, ret, res = core.pcall(call, body, cmd, fd) - if not ok then - logger.error("[rpc.server] call error", ret) - return - end - if not ret then - return - end - --ack - res = res or NIL - if type(ret) == "string" then - ret = proto:tag(ret) - end - local bodydat, sz = proto:encode(ret, res, true) - bodydat, sz = proto:pack(bodydat, sz, true) - tcp_send(fd, np.rpcpack(bodydat, sz, ret, session, traceid)) - --next - fd, buf, size, cmd, session, traceid = np.rpcpop(queue) - if not fd then - return - end - core.trace(traceid) - end - core.trace(otrace) - end - local callback = function(type, fd, message, ...) - np.message(queue, message) - assert(EVENT[type])(fd, ...) - end - local fd, errno = core.tcp_listen(self.addr, callback, self.backlog) - self.fd = fd - return fd, errno -end - -function server.close(self) - gc(self) -end - --------client -local client = {} -local clientmt = {__index = client, __gc = gc} - -local function wakeup_all_calling(self) - local waitpool = self.waitpool - local ackcmd = self.ackcmd - for session, co in pairs(waitpool) do - waitpool[session] = nil - logger.info("[rpc.client] wakeupall session", session) - ackcmd[co] = "closed" - core.wakeup(co) - end -end - -local function wakeup_all_connect(self) - local q = self.connectqueue - for k, v in pairs(q) do - core.wakeup(v) - q[k] = nil - end -end - -local function doconnect(self) - local EVENT = {} - local addr = self.__addr - local close = self.__close - local proto = self.__proto - local ackcmd = self.ackcmd - local waitpool = self.waitpool - function EVENT.close(fd, errno) - if close then - local ok, err = core.pcall(close, fd, errno) - if not ok then - logger.info("[rpc.client] EVENT.close", err) - end - end - self.fd = nil - np.clear(queue, fd) - end - - function EVENT.data() - local fd, d, sz, cmd, session, _ = np.rpcpop(queue) - if not fd then - return - end - core.fork(EVENT.data) - while true do - local str - str, sz = proto:unpack(d, sz, true) - np.drop(d) - local body = proto:decode(cmd, str, sz) - if not body then - logger.error("[rpc.client] decode fail", - session, cmd) - return - end - --ack - local co = waitpool[session] - if not co then --timeout - logger.warn("[rpc.client] late session", - session, cmd) - return - end - waitpool[session] = nil - ackcmd[co] = cmd - core.wakeup(co, body) - --next - fd, d, sz, cmd, session, _ = np.rpcpop(queue) - if not fd then - break - end - end - end - - local callback = function(type, fd, message, ...) - np.message(queue, message) - assert(EVENT[type])(fd, ...) - end - return core.tcp_connect(addr, callback) -end - ---return true/false -local function checkconnect(self) - if self.fd and self.fd >= 0 then - return self.fd - end - if self.closed then - return false - end - if not self.fd then --disconnected - self.fd = -1 - local fd = doconnect(self) - if self.closed then - if fd then - core.socket_close(fd) - fd = nil - end - end - if not fd then - logger.error("[rpc.client] connect", self.__addr, "fail") - self.fd = false - else - self.fd = fd - end - wakeup_all_connect(self) - return self.fd - else - local co = core.running() - local t = self.connectqueue - t[#t + 1] = co - core.wait() - return self.fd and self.fd > 0 - end -end - -local timeout = core.timeout -local timercancel = core.timercancel -local function waitfor(self, expire) - local waitpool = self.waitpool - local ackcmd = self.ackcmd - local timer_func = function(session) - if self.closed then - return - end - local co = waitpool[session] - if not co then - logger.error("[rpc.client] timer error session:", session) - return - end - waitpool[session] = nil - ackcmd[co] = "timeout" - core.wakeup(co) - end - return function(session) - local co = core.running() - local timer = timeout(expire, timer_func, session) - waitpool[session] = co - local body = core.wait() - if body then - timercancel(timer) - end - local cmd = ackcmd[co] - ackcmd[co] = nil - return body, cmd - end -end - -local function send_request(self, cmd, body) - local ok = checkconnect(self) - if not ok then - return false, "closed" - end - local proto = self.__proto - if type(cmd) == "string" then - cmd = proto:tag(cmd) - end - local session = core.genid() - local traceid = core.tracepropagate() - local body, sz = proto:encode(cmd, body, true) - body, sz = proto:pack(body, sz, true) - local ok = tcp_send(self.fd, np.rpcpack(body, sz, cmd, session, traceid)) - if not ok then - return false, "send fail" - end - return true, session -end - -client.send = send_request - -function client.call(self, cmd, body) - local ok, session = send_request(self, cmd, body) - if not ok then - return false, session - end - return self.waitfor(session) -end - -function client.close(self) - if self.closed then - return - end - gc(self) - self.closed = true - wakeup_all_connect(self) - wakeup_all_calling(self) -end - ------rpc -function rpc.connect(config) - local totalwheel = math.floor((config.timeout + 999) / 1000) - local obj = { - fd = false, --false disconnected, -1 conncting, >=0 conncted - closed = false, - connectqueue = {}, - waitpool = {}, - ackcmd = {}, - waitfor = nil, - __addr = config.addr, - __proto = config.proto, - __close = config.close, - } - obj.waitfor = waitfor(obj, config.timeout) - setmetatable(obj, clientmt) - checkconnect(obj) - return obj -end - -function rpc.listen(config) - local obj = { - addr = config.addr, - backlog = config.backlog, - proto = config.proto, - accept = config.accept, - close = config.close, - call = config.call, - } - setmetatable(obj, servermt) - local ok, errno = server_listen(obj) - if not ok then - return nil, errno - end - return obj -end - -return rpc - diff --git a/lualib/core/logger.lua b/lualib/core/logger.lua index 6505351..a6e0c2c 100644 --- a/lualib/core/logger.lua +++ b/lualib/core/logger.lua @@ -44,4 +44,3 @@ function logger.setlevel(level) end return logger - diff --git a/lualib/core/sync/mutex.lua b/lualib/core/sync/mutex.lua index de9b49b..2c01f66 100644 --- a/lualib/core/sync/mutex.lua +++ b/lualib/core/sync/mutex.lua @@ -40,7 +40,7 @@ local proxymt = { end } -function M.new() +function M:new() return setmetatable({ lockobj = {}, }, mt) diff --git a/silly-src/silly_log.c b/silly-src/silly_log.c index 7e62c02..7dd13ed 100644 --- a/silly-src/silly_log.c +++ b/silly-src/silly_log.c @@ -18,7 +18,7 @@ static __thread struct { char *term; time_t sec; time_t msec; - silly_trace_id_t traceid; + silly_traceid_t traceid; } head_cache = { "", NULL, @@ -78,7 +78,7 @@ fmttime() struct tm tm; uint64_t now = silly_timer_now(); time_t sec = now / 1000; - silly_trace_id_t traceid = silly_trace_get(); + silly_traceid_t traceid = silly_trace_get(); int build_step; if (head_cache.sstr == NULL) { build_step = BUILD_SEC; diff --git a/silly-src/silly_trace.c b/silly-src/silly_trace.c index 9aa3f5f..bb60281 100644 --- a/silly-src/silly_trace.c +++ b/silly-src/silly_trace.c @@ -6,47 +6,47 @@ #include "silly_timer.h" #include "silly_trace.h" -static silly_trace_span_t spanid; +static silly_tracespan_t spanid; static uint16_t seq_idx = 0; //63~48, 47~32, 31~16, 15~0 //spanid(16bit),time(16bit),seq(16bit),spanid(16bit) -static __thread silly_trace_id_t trace_ctx = 0; +static __thread silly_traceid_t trace_ctx = 0; void silly_trace_init() { - silly_trace_span((silly_trace_span_t)getpid()); + silly_trace_span((silly_tracespan_t)getpid()); } void -silly_trace_span(silly_trace_span_t id) +silly_trace_span(silly_tracespan_t id) { spanid = id; } -silly_trace_id_t -silly_trace_set(silly_trace_id_t id) +silly_traceid_t +silly_trace_set(silly_traceid_t id) { - silly_trace_id_t old = trace_ctx; + silly_traceid_t old = trace_ctx; trace_ctx = id; return old; } -silly_trace_id_t +silly_traceid_t silly_trace_get() { return trace_ctx; } -silly_trace_id_t +silly_traceid_t silly_trace_new() { if (trace_ctx > 0 ) { - return (trace_ctx & ~((silly_trace_id_t)0xFF)) | (uint64_t)spanid; + return (trace_ctx & ~((silly_traceid_t)0xFF)) | (uint64_t)spanid; } uint16_t time = (uint16_t)(silly_timer_nowsec()); uint16_t seq = atomic_add_return(&seq_idx, 1); - silly_trace_id_t id = (uint64_t)spanid << 48 | + silly_traceid_t id = (uint64_t)spanid << 48 | (uint64_t)time << 32 | (uint64_t)seq << 16 | (uint64_t)spanid; diff --git a/silly-src/silly_trace.h b/silly-src/silly_trace.h index 5260396..f9c097b 100644 --- a/silly-src/silly_trace.h +++ b/silly-src/silly_trace.h @@ -3,15 +3,15 @@ #include -typedef uint16_t silly_trace_span_t; -typedef uint64_t silly_trace_id_t; +typedef uint16_t silly_tracespan_t; +typedef uint64_t silly_traceid_t; void silly_trace_init(); -void silly_trace_span(silly_trace_span_t id); -silly_trace_id_t silly_trace_set(silly_trace_id_t id); -silly_trace_id_t silly_trace_get(); -silly_trace_id_t silly_trace_new(); -silly_trace_id_t silly_trace_propagate(); +void silly_trace_span(silly_tracespan_t id); +silly_traceid_t silly_trace_set(silly_traceid_t id); +silly_traceid_t silly_trace_get(); +silly_traceid_t silly_trace_new(); +silly_traceid_t silly_trace_propagate(); #endif diff --git a/test/testrpc.lua b/test/testrpc.lua index 15ef633..2a92f6e 100644 --- a/test/testrpc.lua +++ b/test/testrpc.lua @@ -1,6 +1,6 @@ local core = require "core" local waitgroup = require "core.sync.waitgroup" -local rpc = require "core.cluster.rpc" +local cluster = require "core.cluster" local crypto = require "core.crypto" local testaux = require "test.testaux" local zproto = require "zproto" @@ -16,40 +16,83 @@ bar 0xfe { } ]] +print(type(logic)) +assert(logic) + local function case_one(msg, cmd, fd) - if cmd == 0xff then - cmd = 0xfe - else - cmd = 0xff - end - return cmd, msg + return msg end local function case_two(msg, cmd, fd) core.sleep(100) - return cmd, msg + return msg end local function case_three(msg, cmd, fd) + core.sleep(2000) +end + +local function unmarshal(cmd, buf, size) + local dat, size = logic:unpack(buf, size, true) + local body = logic:decode(cmd, dat, size) + return body end +local function marshal(cmd, body) + if type(cmd) == "string" then + cmd = logic:tag(cmd) + end + local dat, size = logic:encode(cmd, body, true) + local buf, size = logic:pack(dat, size, true) + return cmd, buf, size +end + + local case = case_one +local accept_fd +local accept_addr +local router = setmetatable({}, {__index = function(t, k) + local fn = function(msg, fd) + return case(msg, k, fd) + end + t[k] = fn + return fn +end}) + +local callret = { + ["foo"] = "bar", + [0xff] = "bar", + ["bar"] = "foo", + [0xfe] = "foo", +} -local server = rpc.listen { - addr = ":8989", - proto = logic, + +local server = cluster.new { + timeout = 1000, + marshal = marshal, + unmarshal = unmarshal, + callret = callret, + router = router, accept = function(fd, addr) + accept_fd = fd + accept_addr = addr end, close = function(fd, errno) end, - - call = function(msg, cmd, fd) - return case(msg, cmd, fd) - end } -local client +server.listen(":8989") +local client_fd +local client = cluster.new { + timeout = 1000, + marshal = marshal, + unmarshal = unmarshal, + callret = callret, + router = router, + close = function(fd, errno) + end, +} local function request(fd, index, count, cmd) return function() @@ -59,9 +102,9 @@ local function request(fd, index, count, cmd) age = index, rand = crypto.randomkey(8), } - local body, ack = client:call(cmd, test) + local body, ack = client[cmd](fd, test) testaux.assertneq(body, nil, "rpc timeout") - testaux.asserteq(test.rand, body.rand, "rpc match request/response") + testaux.asserteq(test.rand, body and body.rand, "rpc match request/response") end end end @@ -74,7 +117,7 @@ local function timeout(fd, index, count, cmd) age = index, rand = crypto.randomkey(8), } - local body, ack = client:call(cmd, test) + local body, ack = client[cmd](fd, test) testaux.asserteq(body, nil, "rpc timeout, body is nil") testaux.asserteq(ack, "timeout", "rpc timeout, ack is timeout") end @@ -84,13 +127,9 @@ end local function client_part() - client = rpc.connect { - addr = "127.0.0.1:8989", - proto = logic, - timeout = 1000, - close = function(fd, errno) - end, - } + local err + client_fd, err = client.connect("127.0.0.1:8989") + print("connect", client_fd, err) local wg = waitgroup:create() case = case_one for i = 1, 2 do @@ -100,26 +139,42 @@ local function client_part() else cmd = "bar" end - wg:fork(request(client, i, 5, cmd)) + wg:fork(request(client_fd, i, 5, cmd)) end wg:wait() print("case one finish") case = case_two for i = 1, 20 do - wg:fork(request(client, i, 50, "foo")) + wg:fork(request(client_fd, i, 50, "foo")) core.sleep(100) end wg:wait() print("case two finish") case = case_three for i = 1, 20 do - wg:fork(timeout(client, i, 2, "foo")) + wg:fork(timeout(client_fd, i, 2, "foo")) core.sleep(10) end wg:wait() print("case three finish") end +local function server_part() + case = case_one + local req = { + name = "hello", + age = 1, + rand = crypto.randomkey(8), + } + local ack, _ = server.foo(accept_fd, req) + testaux.assertneq(ack, nil, "rpc timeout") + testaux.asserteq(req.rand, ack and ack.rand, "rpc match request/response") +end + client_part() -client:close() -server:close() \ No newline at end of file +server_part() +client.close("127.0.0.1:8989") +server.close(":8989") +server.close(accept_addr) +testaux.asserteq(next(client.__fdaddr), nil, "client fdaddr empty") +testaux.asserteq(next(server.__fdaddr), nil, "client fdaddr empty")