diff --git a/Cargo.lock b/Cargo.lock index 11bbd412f8b3..d5cf090157b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1093,6 +1093,8 @@ dependencies = [ "common-meta-api", "common-meta-embedded", "common-meta-types", + "common-proto-conv", + "common-protos", "mockall", "serde_json", ] diff --git a/common/management/Cargo.toml b/common/management/Cargo.toml index 4728e7971aeb..089b602c63ec 100644 --- a/common/management/Cargo.toml +++ b/common/management/Cargo.toml @@ -20,6 +20,8 @@ common-exception = { path = "../exception" } common-functions = { path = "../functions" } common-meta-api = { path = "../meta/api" } common-meta-types = { path = "../meta/types" } +common-proto-conv = { path = "../proto-conv" } +common-protos = { path = "../protos" } async-trait = "0.1.53" serde_json = "1.0.79" diff --git a/common/management/src/lib.rs b/common/management/src/lib.rs index 21d18b6173f9..d3568b96fa67 100644 --- a/common/management/src/lib.rs +++ b/common/management/src/lib.rs @@ -14,6 +14,7 @@ mod cluster; mod role; +mod serde; mod stage; mod udf; mod user; @@ -22,6 +23,8 @@ pub use cluster::ClusterApi; pub use cluster::ClusterMgr; pub use role::RoleApi; pub use role::RoleMgr; +pub use serde::deserialize_struct; +pub use serde::serialize_struct; pub use stage::StageApi; pub use stage::StageMgr; pub use udf::UdfApi; diff --git a/common/management/src/serde/mod.rs b/common/management/src/serde/mod.rs new file mode 100644 index 000000000000..d2c170844ba6 --- /dev/null +++ b/common/management/src/serde/mod.rs @@ -0,0 +1,18 @@ +// Copyright 2021 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod pb_serde; + +pub use pb_serde::deserialize_struct; +pub use pb_serde::serialize_struct; diff --git a/common/management/src/serde/pb_serde.rs b/common/management/src/serde/pb_serde.rs new file mode 100644 index 000000000000..16922cb0b9a9 --- /dev/null +++ b/common/management/src/serde/pb_serde.rs @@ -0,0 +1,55 @@ +// Copyright 2021 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; + +use common_exception::ErrorCode; +use common_exception::Result; +use common_exception::ToErrorCode; +use common_proto_conv::FromToProto; + +pub fn serialize_struct( + value: &impl FromToProto, + err_code_fn: ErrFn, + context_fn: CtxFn, +) -> Result> +where + ErrFn: FnOnce(String) -> ErrorCode + std::marker::Copy, + D: Display, + CtxFn: FnOnce() -> D + std::marker::Copy, +{ + let p = value.to_pb().map_err_to_code(err_code_fn, context_fn)?; + let mut buf = vec![]; + common_protos::prost::Message::encode(&p, &mut buf).map_err_to_code(err_code_fn, context_fn)?; + Ok(buf) +} + +pub fn deserialize_struct( + buf: &[u8], + err_code_fn: ErrFn, + context_fn: CtxFn, +) -> Result +where + PB: common_protos::prost::Message + Default, + T: FromToProto, + ErrFn: FnOnce(String) -> ErrorCode + std::marker::Copy, + D: Display, + CtxFn: FnOnce() -> D + std::marker::Copy, +{ + let p: PB = + common_protos::prost::Message::decode(buf).map_err_to_code(err_code_fn, context_fn)?; + let v: T = FromToProto::from_pb(p).map_err_to_code(err_code_fn, context_fn)?; + + Ok(v) +} diff --git a/common/management/src/stage/stage_mgr.rs b/common/management/src/stage/stage_mgr.rs index b63eca2b8996..5b7243bcdcdc 100644 --- a/common/management/src/stage/stage_mgr.rs +++ b/common/management/src/stage/stage_mgr.rs @@ -18,7 +18,6 @@ use common_base::base::escape_for_key; use common_exception::ErrorCode; use common_exception::Result; use common_meta_api::KVApi; -use common_meta_types::IntoSeqV; use common_meta_types::MatchSeq; use common_meta_types::MatchSeqExt; use common_meta_types::OkOrExist; @@ -27,6 +26,8 @@ use common_meta_types::SeqV; use common_meta_types::UpsertKVAction; use common_meta_types::UserStageInfo; +use crate::serde::deserialize_struct; +use crate::serde::serialize_struct; use crate::stage::StageApi; static USER_STAGE_API_KEY_PREFIX: &str = "__fd_stages"; @@ -55,7 +56,11 @@ impl StageMgr { impl StageApi for StageMgr { async fn add_stage(&self, info: UserStageInfo) -> Result { let seq = MatchSeq::Exact(0); - let val = Operation::Update(serde_json::to_vec(&info)?); + let val = Operation::Update(serialize_struct( + &info, + ErrorCode::IllegalUserStageFormat, + || "", + )?); let key = format!( "{}/{}", self.stage_prefix, @@ -85,7 +90,10 @@ impl StageApi for StageMgr { res.ok_or_else(|| ErrorCode::UnknownStage(format!("Unknown stage {}", name)))?; match MatchSeq::from(seq).match_seq(&seq_value) { - Ok(_) => Ok(seq_value.into_seqv()?), + Ok(_) => Ok(SeqV::new( + seq_value.seq, + deserialize_struct(&seq_value.data, ErrorCode::IllegalUserStageFormat, || "")?, + )), Err(_) => Err(ErrorCode::UnknownStage(format!("Unknown stage {}", name))), } } @@ -95,7 +103,8 @@ impl StageApi for StageMgr { let mut stage_infos = Vec::with_capacity(values.len()); for (_, value) in values { - let stage_info = serde_json::from_slice::(&value.data)?; + let stage_info = + deserialize_struct(&value.data, ErrorCode::IllegalUserStageFormat, || "")?; stage_infos.push(stage_info); } Ok(stage_infos) diff --git a/common/management/src/user/user_mgr.rs b/common/management/src/user/user_mgr.rs index 2177a55db553..ecb24d798478 100644 --- a/common/management/src/user/user_mgr.rs +++ b/common/management/src/user/user_mgr.rs @@ -17,11 +17,9 @@ use std::sync::Arc; use common_base::base::escape_for_key; use common_exception::ErrorCode; use common_exception::Result; -use common_exception::ToErrorCode; use common_meta_api::KVApi; use common_meta_types::AuthInfo; use common_meta_types::GrantObject; -use common_meta_types::IntoSeqV; use common_meta_types::MatchSeq; use common_meta_types::MatchSeqExt; use common_meta_types::OkOrExist; @@ -33,6 +31,8 @@ use common_meta_types::UserInfo; use common_meta_types::UserOption; use common_meta_types::UserPrivilegeSet; +use crate::serde::deserialize_struct; +use crate::serde::serialize_struct; use crate::user::user_api::UserApi; static USER_API_KEY_PREFIX: &str = "__fd_users"; @@ -63,7 +63,7 @@ impl UserMgr { ) -> common_exception::Result { let user_key = format_user_key(&user_info.name, &user_info.hostname); let key = format!("{}/{}", self.user_prefix, escape_for_key(&user_key)?); - let value = serde_json::to_vec(&user_info)?; + let value = serialize_struct(user_info, ErrorCode::IllegalUserInfoFormat, || "")?; let match_seq = match seq { None => MatchSeq::GE(1), @@ -95,7 +95,7 @@ impl UserApi for UserMgr { let match_seq = MatchSeq::Exact(0); let user_key = format_user_key(&user_info.name, &user_info.hostname); let key = format!("{}/{}", self.user_prefix, escape_for_key(&user_key)?); - let value = serde_json::to_vec(&user_info)?; + let value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; let kv_api = self.kv_api.clone(); let upsert_kv = kv_api.upsert_kv(UpsertKVAction::new( @@ -122,7 +122,10 @@ impl UserApi for UserMgr { res.ok_or_else(|| ErrorCode::UnknownUser(format!("unknown user {}", user_key)))?; match MatchSeq::from(seq).match_seq(&seq_value) { - Ok(_) => Ok(seq_value.into_seqv()?), + Ok(_) => Ok(SeqV::new( + seq_value.seq, + deserialize_struct(&seq_value.data, ErrorCode::IllegalUserInfoFormat, || "")?, + )), Err(_) => Err(ErrorCode::UnknownUser(format!("unknown user {}", user_key))), } } @@ -133,8 +136,7 @@ impl UserApi for UserMgr { let mut r = vec![]; for (_key, val) in values { - let u = serde_json::from_slice::(&val.data) - .map_err_to_code(ErrorCode::IllegalUserInfoFormat, || "")?; + let u = deserialize_struct(&val.data, ErrorCode::IllegalUserInfoFormat, || "")?; r.push(SeqV::new(val.seq, u)); } diff --git a/common/management/tests/it/stage.rs b/common/management/tests/it/stage.rs index 398568b60b15..98f736927b73 100644 --- a/common/management/tests/it/stage.rs +++ b/common/management/tests/it/stage.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use common_base::base::tokio; +use common_exception::ErrorCode; use common_exception::Result; use common_management::*; use common_meta_api::KVApi; @@ -36,7 +37,10 @@ async fn test_add_stage() -> Result<()> { meta: _, data: value, }) => { - assert_eq!(value, serde_json::to_vec(&stage_info)?); + assert_eq!( + value, + serialize_struct(&stage_info, ErrorCode::IllegalUserStageFormat, || "")? + ); } catch => panic!("GetKVActionReply{:?}", catch), } diff --git a/common/management/tests/it/user.rs b/common/management/tests/it/user.rs index b4e2ecf89395..af6d9844e813 100644 --- a/common/management/tests/it/user.rs +++ b/common/management/tests/it/user.rs @@ -83,8 +83,13 @@ mod add { let test_user_name = "test_user"; let test_hostname = "localhost"; let user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let v = serde_json::to_vec(&user_info)?; - let value = Operation::Update(serde_json::to_vec(&user_info)?); + + let v = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; + let value = Operation::Update(serialize_struct( + &user_info, + ErrorCode::IllegalUserInfoFormat, + || "", + )?); let test_key = format!( "__fd_users/tenant1/{}", @@ -188,7 +193,7 @@ mod get { ); let user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let value = serde_json::to_vec(&user_info)?; + let value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; let mut kv = MockKV::new(); kv.expect_get_kv() @@ -214,7 +219,7 @@ mod get { ); let user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let value = serde_json::to_vec(&user_info)?; + let value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; let mut kv = MockKV::new(); kv.expect_get_kv() @@ -333,7 +338,10 @@ mod get_users { let user_info = UserInfo::new(&name, &hostname, default_test_auth_info()); res.push(( "fake_key".to_string(), - SeqV::new(i, serde_json::to_vec(&user_info)?), + SeqV::new( + i, + serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?, + ), )); user_infos.push(SeqV::new(i, user_info)); } @@ -489,7 +497,7 @@ mod update { let test_seq = None; let user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let prev_value = serde_json::to_vec(&user_info)?; + let prev_value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; // get_kv should be called let mut kv = MockKV::new(); @@ -503,7 +511,8 @@ mod update { // and then, update_kv should be called let new_user_info = UserInfo::new(test_user_name, test_hostname, new_test_auth_info(full)); - let new_value_with_old_salt = serde_json::to_vec(&new_user_info)?; + let new_value_with_old_salt = + serialize_struct(&new_user_info, ErrorCode::IllegalUserInfoFormat, || "")?; kv.expect_upsert_kv() .with(predicate::eq(UpsertKVAction::new( @@ -574,7 +583,7 @@ mod update { let test_seq = None; let user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let prev_value = serde_json::to_vec(&user_info)?; + let prev_value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; // - get_kv should be called let mut kv = MockKV::new(); @@ -630,7 +639,7 @@ mod set_user_privileges { let test_seq = None; let mut user_info = UserInfo::new(test_user_name, test_hostname, default_test_auth_info()); - let prev_value = serde_json::to_vec(&user_info)?; + let prev_value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; // - get_kv should be called let mut kv = MockKV::new(); @@ -647,7 +656,7 @@ mod set_user_privileges { user_info .grants .grant_privileges(&GrantObject::Global, privileges); - let new_value = serde_json::to_vec(&user_info)?; + let new_value = serialize_struct(&user_info, ErrorCode::IllegalUserInfoFormat, || "")?; kv.expect_upsert_kv() .with(predicate::eq(UpsertKVAction::new(