From 5e3199098d44d2c10265e40a0ff7b39ef7819615 Mon Sep 17 00:00:00 2001 From: Adam Harrison Date: Fri, 17 Jan 2025 13:22:47 -0500 Subject: [PATCH] Fixed chunked transfer encoding. --- src/lpm.c | 259 +++++++++++++++++++++++++++------------------------- src/lpm.lua | 6 +- 2 files changed, 140 insertions(+), 125 deletions(-) diff --git a/src/lpm.c b/src/lpm.c index b95e33a..5c175f4 100644 --- a/src/lpm.c +++ b/src/lpm.c @@ -1214,7 +1214,6 @@ static int lpm_extract(lua_State* L) { case MTAR_TGLP: { has_ext_before = 1; int read_size = imin(h.size, (int)sizeof(before_h.linkname)); - fprintf(stderr, "READ SIZE: %d\n", read_size); if (mtar_read_data(&tar, before_h.linkname, read_size) != MTAR_ESUCCESS) { mtar_close(&tar); return luaL_error(L, "Error while reading GNU extended: %s", strerror(errno)); @@ -1512,144 +1511,160 @@ static int lpm_extract(lua_State* L) { lua_rawgeti(L, LUA_REGISTRYINDEX, ctx); get_context_t* context = (get_context_t*)lua_touserdata(L, -1); lua_pop(L,1); - switch (context->state) { - case STATE_HANDSHAKE: { - int status = mbedtls_ssl_handshake(&context->ssl); - if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) - return lua_yieldk(L, 0, ctx, lpm_getk); - if ( - lpm_get_error(context, status, "can't handshake") || - lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result") - ) - goto cleanup; - context->state = STATE_SEND; - } - case STATE_SEND: { - context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname); - int length = lpm_socket_write(context, context->buffer_length); - if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket")) - goto cleanup; - context->buffer_length = 0; - context->buffer[0] = 0; - context->state = STATE_RECV_HEADER; - } - case STATE_RECV_HEADER: { - const char* header_end; - while (1) { - header_end = strstr(context->buffer, "\r\n\r\n"); - if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded")) + int is_main_thread = lua_is_main_thread(L); + while (1) { + switch (context->state) { + case STATE_HANDSHAKE: { + int status = mbedtls_ssl_handshake(&context->ssl); + if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (is_main_thread) + break; + return lua_yieldk(L, 0, ctx, lpm_getk); + } + if ( + lpm_get_error(context, status, "can't handshake") || + lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result") + ) goto cleanup; - if (!header_end) { - int length = lpm_socket_read(context, -1); - if (length < 0 && lpm_get_error(context, length, "can't read from socket")) - goto cleanup; - if (length == 0) - return lua_yieldk(L, 0, ctx, lpm_getk); - } else { - header_end += 4; - const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length); - int code = atoi(protocol_end + 1); - if (code != 200) { - if (code >= 301 && code <= 303) { - const char* location = get_header(context->buffer, "location", &context->buffer_length); - if (location) { - lua_pushnil(L); - lua_newtable(L); - lua_pushlstring(L, location, context->buffer_length); - lua_setfield(L, -2, "location"); - } else - lpm_set_error(context, "received invalid %d-response", code); - } else - lpm_set_error(context, "received non 200-response of %d", code); - goto report; - } - const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL); - context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0; - const char* content_length_value = get_header(context->buffer, "content-length", NULL); - context->content_length = content_length_value ? atoi(content_length_value) : -1; - context->buffer_length -= (header_end - context->buffer); - if (context->buffer_length > 0) - memmove(context->buffer, header_end, context->buffer_length); - context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length; - context->state = STATE_RECV_BODY; - break; - } + context->state = STATE_SEND; } - } - case STATE_RECV_BODY: { - while (1) { - // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk. - while (context->chunk_length == -1) { - char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length); - if (newline) { - *newline = '\0'; - if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length"))) - goto cleanup; - else if (context->chunk_length == 0) - goto finish; - context->buffer_length -= (newline + 2 - context->buffer); - if (context->buffer_length > 0) - memmove(context->buffer, newline + 2, context->buffer_length); - } else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) { + case STATE_SEND: { + context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname); + int length = lpm_socket_write(context, context->buffer_length); + if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket")) + goto cleanup; + context->buffer_length = 0; + context->buffer[0] = 0; + context->state = STATE_RECV_HEADER; + } + case STATE_RECV_HEADER: { + const char* header_end; + while (1) { + header_end = strstr(context->buffer, "\r\n\r\n"); + if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded")) goto cleanup; - } else { + if (!header_end) { int length = lpm_socket_read(context, -1); - if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse")) + if (length < 0 && lpm_get_error(context, length, "can't read from socket")) goto cleanup; - if (length == 0) + if (length == 0) { + if (is_main_thread) + break; return lua_yieldk(L, 0, ctx, lpm_getk); + } + } else { + header_end += 4; + const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length); + int code = atoi(protocol_end + 1); + if (code != 200) { + if (code >= 301 && code <= 303) { + const char* location = get_header(context->buffer, "location", &context->buffer_length); + if (location) { + lua_pushnil(L); + lua_newtable(L); + lua_pushlstring(L, location, context->buffer_length); + lua_setfield(L, -2, "location"); + } else + lpm_set_error(context, "received invalid %d-response", code); + } else + lpm_set_error(context, "received non 200-response of %d", code); + goto report; + } + const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL); + context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0; + const char* content_length_value = get_header(context->buffer, "content-length", NULL); + context->content_length = content_length_value ? atoi(content_length_value) : -1; + context->buffer_length -= (header_end - context->buffer); + if (context->buffer_length > 0) + memmove(context->buffer, header_end, context->buffer_length); + context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length; + context->state = STATE_RECV_BODY; + break; } } - if (context->buffer_length > 0) { - int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length); - if (to_write > 0) { - context->total_downloaded += to_write; - context->chunk_written += to_write; - if (context->callback_function) { - lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function); - lua_pushinteger(L, context->total_downloaded); - if (context->content_length == -1) - lua_pushnil(L); - else - lua_pushinteger(L, context->content_length); - lua_call(L, 2, 0); + } + case STATE_RECV_BODY: { + while (1) { + // If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk. + while (context->chunk_length == -1) { + char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length); + if (newline) { + *newline = '\0'; + if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length"))) + goto cleanup; + else if (context->chunk_length == 0) + goto finish; + context->chunk_written = 0; + context->buffer_length -= (newline + 2 - context->buffer); + if (context->buffer_length > 0) + memmove(context->buffer, newline + 2, context->buffer_length); + } else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) { + goto cleanup; + } else { + int length = lpm_socket_read(context, -1); + if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse")) + goto cleanup; + if (length == 0) { + if (is_main_thread) + break; + return lua_yieldk(L, 0, ctx, lpm_getk); + } } - if (context->file) - fwrite(context->buffer, sizeof(char), to_write, context->file); - else { - lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer); - lua_pushlstring(L, context->buffer, to_write); - lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); - lua_pop(L, 1); + } + if (context->buffer_length > 0) { + int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length); + if (to_write > 0) { + context->total_downloaded += to_write; + context->chunk_written += to_write; + if (context->callback_function) { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function); + lua_pushinteger(L, context->total_downloaded); + if (context->content_length == -1) + lua_pushnil(L); + else + lua_pushinteger(L, context->content_length); + lua_call(L, 2, 0); + } + if (context->file) + fwrite(context->buffer, sizeof(char), to_write, context->file); + else { + lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer); + lua_pushlstring(L, context->buffer, to_write); + lua_rawseti(L, -2, lua_rawlen(L, -2) + 1); + lua_pop(L, 1); + } + context->buffer_length -= to_write; + if (context->buffer_length > 0) + memmove(context->buffer, &context->buffer[to_write], context->buffer_length); + } + if (context->chunk_written == context->chunk_length) { + if (!context->chunked) + goto finish; + if (context->buffer_length >= 2) { + if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk")) + goto cleanup; + memmove(context->buffer, &context->buffer[2], context->buffer_length - 2); + context->buffer_length -= 2; + context->chunk_length = -1; + } } - context->buffer_length -= to_write; - if (context->buffer_length > 0) - memmove(context->buffer, &context->buffer[to_write], context->buffer_length); } - if (context->chunk_written == context->chunk_length) { - if (!context->chunked) + if (context->chunk_length > 0) { + int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0))); + if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) goto finish; - if (context->buffer_length >= 2) { - if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk")) - goto cleanup; - memmove(context->buffer, &context->buffer[2], context->buffer_length - 2); - context->buffer_length -= 2; - context->chunk_length = -1; + if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk")) + goto cleanup; + if (length == 0) { + if (is_main_thread) + break; + return lua_yieldk(L, 0, ctx, lpm_getk); } } } - if (context->chunk_length > 0) { - int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0))); - if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) - goto finish; - if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk")) - goto cleanup; - if (length == 0) - return lua_yieldk(L, 0, ctx, lpm_getk); - } } + default: break; } - default: break; } finish: if (context->file) { diff --git a/src/lpm.lua b/src/lpm.lua index d12ce42..8d8f85f 100644 --- a/src/lpm.lua +++ b/src/lpm.lua @@ -509,11 +509,11 @@ function common.args(arguments, options) for k,v in pairs(arguments) do if math.type(k) ~= "integer" then args[k] = v end end while i <= #arguments do local s,e, option, value = arguments[i]:find("%-%-([^=]+)=?(.*)") - local option_name = s and option:gsub("^no%-", "") + local option_name = s and (options[option] and option or option:gsub("^no%-", "")) if options[option_name] then local flag_type = options[option_name] if flag_type == "flag" then - args[option] = not option:find("^no-") and true or false + args[option] = (option_name == option or not option:find("^no-")) and true or false elseif flag_type == "string" or flag_type == "number" or flag_type == "array" then if not value or value == "" then if i == #arguments then error("option " .. option .. " requires a " .. flag_type) end @@ -2584,7 +2584,7 @@ Flags have the following effects: for the system lite-xl. --binary=path Sets the lite-xl binary path for the system lite-xl. --verbose Spits out more information, including intermediate - steps to install and whatnot. + steps to install and whatnot. --quiet Outputs nothing but explicit responses. --mod-version=version Sets the mod version of lite-xl to install addons. Can be set to "any", which will retrieve the latest