diff --git a/VM/include/lua.h b/VM/include/lua.h index a2c722135..a556cc4b6 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -294,6 +294,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_islightuserdata(L, n) (lua_type(L, (n)) == LUA_TLIGHTUSERDATA) #define lua_isnil(L, n) (lua_type(L, (n)) == LUA_TNIL) #define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) +#define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR) #define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) #define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 54b008ffb..baf27b47e 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -33,6 +33,9 @@ LUALIB_API int luaL_optinteger(lua_State* L, int nArg, int def); LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int numArg); LUALIB_API unsigned luaL_optunsigned(lua_State* L, int numArg, unsigned def); +LUALIB_API const float* luaL_checkvector(lua_State* L, int narg); +LUALIB_API const float* luaL_optvector(lua_State* L, int narg, const float* def); + LUALIB_API void luaL_checkstack(lua_State* L, int sz, const char* msg); LUALIB_API void luaL_checktype(lua_State* L, int narg, int t); LUALIB_API void luaL_checkany(lua_State* L, int narg); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index a5e54358d..7ed2a62ee 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -227,6 +227,19 @@ unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) return luaL_opt(L, luaL_checkunsigned, narg, def); } +const float* luaL_checkvector(lua_State* L, int narg) +{ + const float* v = lua_tovector(L, narg); + if (!v) + tag_error(L, narg, LUA_TVECTOR); + return v; +} + +const float* luaL_optvector(lua_State* L, int narg, const float* def) +{ + return luaL_opt(L, luaL_checkvector, narg, def); +} + int luaL_getmetafield(lua_State* L, int obj, const char* event) { if (!lua_getmetatable(L, obj)) /* no metatable? */ diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index a573ae42d..5ee842449 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -78,38 +78,31 @@ static int lua_vector(lua_State* L) static int lua_vector_dot(lua_State* L) { - const float* a = lua_tovector(L, 1); - const float* b = lua_tovector(L, 2); + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); - if (a && b) - { - lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); - return 1; - } - - throw std::runtime_error("invalid arguments to vector:Dot"); + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; } static int lua_vector_index(lua_State* L) { + const float* v = luaL_checkvector(L, 1); const char* name = luaL_checkstring(L, 2); - if (const float* v = lua_tovector(L, 1)) + if (strcmp(name, "Magnitude") == 0) { - if (strcmp(name, "Magnitude") == 0) - { - lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); - return 1; - } + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; + } - if (strcmp(name, "Dot") == 0) - { - lua_pushcfunction(L, lua_vector_dot, "Dot"); - return 1; - } + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; } - throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); + luaL_error(L, "%s is not a valid member of vector", name); } static int lua_vector_namecall(lua_State* L) @@ -120,7 +113,7 @@ static int lua_vector_namecall(lua_State* L) return lua_vector_dot(L); } - throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); } int lua_silence(lua_State* L)