Skip to content

Commit

Permalink
Fixed chunked transfer encoding.
Browse files Browse the repository at this point in the history
  • Loading branch information
adamharrison committed Jan 17, 2025
1 parent d0b6849 commit 5e31990
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 125 deletions.
259 changes: 137 additions & 122 deletions src/lpm.c
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,6 @@ static int lpm_extract(lua_State* L) {
case MTAR_TGLP: {
has_ext_before = 1;
int read_size = imin(h.size, (int)sizeof(before_h.linkname));
fprintf(stderr, "READ SIZE: %d\n", read_size);
if (mtar_read_data(&tar, before_h.linkname, read_size) != MTAR_ESUCCESS) {
mtar_close(&tar);
return luaL_error(L, "Error while reading GNU extended: %s", strerror(errno));
Expand Down Expand Up @@ -1512,144 +1511,160 @@ static int lpm_extract(lua_State* L) {
lua_rawgeti(L, LUA_REGISTRYINDEX, ctx);
get_context_t* context = (get_context_t*)lua_touserdata(L, -1);
lua_pop(L,1);
switch (context->state) {
case STATE_HANDSHAKE: {
int status = mbedtls_ssl_handshake(&context->ssl);
if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE)
return lua_yieldk(L, 0, ctx, lpm_getk);
if (
lpm_get_error(context, status, "can't handshake") ||
lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result")
)
goto cleanup;
context->state = STATE_SEND;
}
case STATE_SEND: {
context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname);
int length = lpm_socket_write(context, context->buffer_length);
if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket"))
goto cleanup;
context->buffer_length = 0;
context->buffer[0] = 0;
context->state = STATE_RECV_HEADER;
}
case STATE_RECV_HEADER: {
const char* header_end;
while (1) {
header_end = strstr(context->buffer, "\r\n\r\n");
if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded"))
int is_main_thread = lua_is_main_thread(L);
while (1) {
switch (context->state) {
case STATE_HANDSHAKE: {
int status = mbedtls_ssl_handshake(&context->ssl);
if (status == MBEDTLS_ERR_SSL_WANT_READ || status == MBEDTLS_ERR_SSL_WANT_WRITE) {
if (is_main_thread)
break;
return lua_yieldk(L, 0, ctx, lpm_getk);
}
if (
lpm_get_error(context, status, "can't handshake") ||
lpm_get_error(context, mbedtls_ssl_get_verify_result(&context->ssl), "can't verify result")
)
goto cleanup;
if (!header_end) {
int length = lpm_socket_read(context, -1);
if (length < 0 && lpm_get_error(context, length, "can't read from socket"))
goto cleanup;
if (length == 0)
return lua_yieldk(L, 0, ctx, lpm_getk);
} else {
header_end += 4;
const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length);
int code = atoi(protocol_end + 1);
if (code != 200) {
if (code >= 301 && code <= 303) {
const char* location = get_header(context->buffer, "location", &context->buffer_length);
if (location) {
lua_pushnil(L);
lua_newtable(L);
lua_pushlstring(L, location, context->buffer_length);
lua_setfield(L, -2, "location");
} else
lpm_set_error(context, "received invalid %d-response", code);
} else
lpm_set_error(context, "received non 200-response of %d", code);
goto report;
}
const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL);
context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0;
const char* content_length_value = get_header(context->buffer, "content-length", NULL);
context->content_length = content_length_value ? atoi(content_length_value) : -1;
context->buffer_length -= (header_end - context->buffer);
if (context->buffer_length > 0)
memmove(context->buffer, header_end, context->buffer_length);
context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length;
context->state = STATE_RECV_BODY;
break;
}
context->state = STATE_SEND;
}
}
case STATE_RECV_BODY: {
while (1) {
// If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk.
while (context->chunk_length == -1) {
char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length);
if (newline) {
*newline = '\0';
if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length")))
goto cleanup;
else if (context->chunk_length == 0)
goto finish;
context->buffer_length -= (newline + 2 - context->buffer);
if (context->buffer_length > 0)
memmove(context->buffer, newline + 2, context->buffer_length);
} else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) {
case STATE_SEND: {
context->buffer_length = snprintf(context->buffer, sizeof(context->buffer), "GET %s HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", context->rest, context->hostname);
int length = lpm_socket_write(context, context->buffer_length);
if (length < context->buffer_length && lpm_get_error(context, length, "can't write to socket"))
goto cleanup;
context->buffer_length = 0;
context->buffer[0] = 0;
context->state = STATE_RECV_HEADER;
}
case STATE_RECV_HEADER: {
const char* header_end;
while (1) {
header_end = strstr(context->buffer, "\r\n\r\n");
if (!header_end && context->buffer_length >= sizeof(context->buffer) - 1 && lpm_set_error(context, "response header buffer length exceeded"))
goto cleanup;
} else {
if (!header_end) {
int length = lpm_socket_read(context, -1);
if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse"))
if (length < 0 && lpm_get_error(context, length, "can't read from socket"))
goto cleanup;
if (length == 0)
if (length == 0) {
if (is_main_thread)
break;
return lua_yieldk(L, 0, ctx, lpm_getk);
}
} else {
header_end += 4;
const char* protocol_end = strnstr_local(context->buffer, " ", context->buffer_length);
int code = atoi(protocol_end + 1);
if (code != 200) {
if (code >= 301 && code <= 303) {
const char* location = get_header(context->buffer, "location", &context->buffer_length);
if (location) {
lua_pushnil(L);
lua_newtable(L);
lua_pushlstring(L, location, context->buffer_length);
lua_setfield(L, -2, "location");
} else
lpm_set_error(context, "received invalid %d-response", code);
} else
lpm_set_error(context, "received non 200-response of %d", code);
goto report;
}
const char* transfer_encoding = get_header(context->buffer, "transfer-encoding", NULL);
context->chunked = transfer_encoding && strncmp(transfer_encoding, "chunked", 7) == 0 ? 1 : 0;
const char* content_length_value = get_header(context->buffer, "content-length", NULL);
context->content_length = content_length_value ? atoi(content_length_value) : -1;
context->buffer_length -= (header_end - context->buffer);
if (context->buffer_length > 0)
memmove(context->buffer, header_end, context->buffer_length);
context->chunk_length = !context->chunked && context->content_length == -1 ? INT_MAX : context->content_length;
context->state = STATE_RECV_BODY;
break;
}
}
if (context->buffer_length > 0) {
int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length);
if (to_write > 0) {
context->total_downloaded += to_write;
context->chunk_written += to_write;
if (context->callback_function) {
lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function);
lua_pushinteger(L, context->total_downloaded);
if (context->content_length == -1)
lua_pushnil(L);
else
lua_pushinteger(L, context->content_length);
lua_call(L, 2, 0);
}
case STATE_RECV_BODY: {
while (1) {
// If we have an unknown amount of chunk bytes to be fetched, determine the size of the next chunk.
while (context->chunk_length == -1) {
char* newline = (char*)strnstr_local(context->buffer, "\r\n", context->buffer_length);
if (newline) {
*newline = '\0';
if ((sscanf(context->buffer, "%x", &context->chunk_length) != 1 && lpm_set_error(context, "error retrieving chunk length")))
goto cleanup;
else if (context->chunk_length == 0)
goto finish;
context->chunk_written = 0;
context->buffer_length -= (newline + 2 - context->buffer);
if (context->buffer_length > 0)
memmove(context->buffer, newline + 2, context->buffer_length);
} else if (context->buffer_length >= sizeof(context->buffer) && lpm_set_error(context, "can't find chunk length")) {
goto cleanup;
} else {
int length = lpm_socket_read(context, -1);
if ((length <= 0 || (context->is_ssl && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)) && lpm_get_error(context, length, "error retrieving full repsonse"))
goto cleanup;
if (length == 0) {
if (is_main_thread)
break;
return lua_yieldk(L, 0, ctx, lpm_getk);
}
}
if (context->file)
fwrite(context->buffer, sizeof(char), to_write, context->file);
else {
lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer);
lua_pushlstring(L, context->buffer, to_write);
lua_rawseti(L, -2, lua_rawlen(L, -2) + 1);
lua_pop(L, 1);
}
if (context->buffer_length > 0) {
int to_write = imin(context->chunk_length - context->chunk_written, context->buffer_length);
if (to_write > 0) {
context->total_downloaded += to_write;
context->chunk_written += to_write;
if (context->callback_function) {
lua_rawgeti(L, LUA_REGISTRYINDEX, context->callback_function);
lua_pushinteger(L, context->total_downloaded);
if (context->content_length == -1)
lua_pushnil(L);
else
lua_pushinteger(L, context->content_length);
lua_call(L, 2, 0);
}
if (context->file)
fwrite(context->buffer, sizeof(char), to_write, context->file);
else {
lua_rawgeti(L, LUA_REGISTRYINDEX, context->lua_buffer);
lua_pushlstring(L, context->buffer, to_write);
lua_rawseti(L, -2, lua_rawlen(L, -2) + 1);
lua_pop(L, 1);
}
context->buffer_length -= to_write;
if (context->buffer_length > 0)
memmove(context->buffer, &context->buffer[to_write], context->buffer_length);
}
if (context->chunk_written == context->chunk_length) {
if (!context->chunked)
goto finish;
if (context->buffer_length >= 2) {
if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk"))
goto cleanup;
memmove(context->buffer, &context->buffer[2], context->buffer_length - 2);
context->buffer_length -= 2;
context->chunk_length = -1;
}
}
context->buffer_length -= to_write;
if (context->buffer_length > 0)
memmove(context->buffer, &context->buffer[to_write], context->buffer_length);
}
if (context->chunk_written == context->chunk_length) {
if (!context->chunked)
if (context->chunk_length > 0) {
int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0)));
if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
goto finish;
if (context->buffer_length >= 2) {
if (!strnstr_local(context->buffer, "\r\n", 2) && lpm_set_error(context, "invalid end to chunk"))
goto cleanup;
memmove(context->buffer, &context->buffer[2], context->buffer_length - 2);
context->buffer_length -= 2;
context->chunk_length = -1;
if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk"))
goto cleanup;
if (length == 0) {
if (is_main_thread)
break;
return lua_yieldk(L, 0, ctx, lpm_getk);
}
}
}
if (context->chunk_length > 0) {
int length = lpm_socket_read(context, imin(sizeof(context->buffer) - context->buffer_length, context->chunk_length - context->chunk_written + (context->chunked ? 2 : 0)));
if ((!context->is_ssl && length == 0) || (context->is_ssl && context->content_length == -1 && length == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY))
goto finish;
if (length < 0 && lpm_get_error(context, length, "error retrieving full chunk"))
goto cleanup;
if (length == 0)
return lua_yieldk(L, 0, ctx, lpm_getk);
}
}
default: break;
}
default: break;
}
finish:
if (context->file) {
Expand Down
6 changes: 3 additions & 3 deletions src/lpm.lua
Original file line number Diff line number Diff line change
Expand Up @@ -509,11 +509,11 @@ function common.args(arguments, options)
for k,v in pairs(arguments) do if math.type(k) ~= "integer" then args[k] = v end end
while i <= #arguments do
local s,e, option, value = arguments[i]:find("%-%-([^=]+)=?(.*)")
local option_name = s and option:gsub("^no%-", "")
local option_name = s and (options[option] and option or option:gsub("^no%-", ""))
if options[option_name] then
local flag_type = options[option_name]
if flag_type == "flag" then
args[option] = not option:find("^no-") and true or false
args[option] = (option_name == option or not option:find("^no-")) and true or false
elseif flag_type == "string" or flag_type == "number" or flag_type == "array" then
if not value or value == "" then
if i == #arguments then error("option " .. option .. " requires a " .. flag_type) end
Expand Down Expand Up @@ -2584,7 +2584,7 @@ Flags have the following effects:
for the system lite-xl.
--binary=path Sets the lite-xl binary path for the system lite-xl.
--verbose Spits out more information, including intermediate
steps to install and whatnot.
steps to install and whatnot.
--quiet Outputs nothing but explicit responses.
--mod-version=version Sets the mod version of lite-xl to install addons.
Can be set to "any", which will retrieve the latest
Expand Down

0 comments on commit 5e31990

Please sign in to comment.