diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 79b1aa9..91c54d2 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -3775,7 +3775,7 @@ dependencies = [ [[package]] name = "viam-rust-utils" -version = "0.2.0" +version = "0.2.4" dependencies = [ "anyhow", "base64 0.13.1", diff --git a/src/ffi/dial_ffi.rs b/src/ffi/dial_ffi.rs index 7deabbe..4529a57 100644 --- a/src/ffi/dial_ffi.rs +++ b/src/ffi/dial_ffi.rs @@ -196,10 +196,13 @@ pub unsafe extern "C" fn dial( false => match CStr::from_ptr(c_entity).to_str() { Ok(ent) => Some(ent.to_string()), Err(e) => { - log::error!("Error unexpectedly received an invalid entity string {:?}", e); + log::error!( + "Error unexpectedly received an invalid entity string {:?}", + e + ); return ptr::null_mut(); } - } + }, } }; let timeout_duration = Duration::from_secs_f32(c_timeout); @@ -207,26 +210,26 @@ pub unsafe extern "C" fn dial( let (server, channel) = match runtime.block_on(async move { let channel = match (r#type, payload) { (Some(t), Some(p)) => { - let res = timeout( - timeout_duration, + timeout( + timeout_duration, dial_with_cred( - uri_str, - entity_opt, - t.to_str()?, - p.to_str()?, - allow_insec, - disable_webrtc, - )? - .connect() + uri_str, + entity_opt, + t.to_str()?, + p.to_str()?, + allow_insec, + disable_webrtc, + )? + .connect(), ) - .await?; - res + .await? } (None, None) => { - let res = timeout( - timeout_duration, - dial_without_cred(uri_str, allow_insec, disable_webrtc)?.connect()).await?; - res + timeout( + timeout_duration, + dial_without_cred(uri_str, allow_insec, disable_webrtc)?.connect(), + ) + .await? } (None, Some(_)) => Err(anyhow::anyhow!("Error missing credential: type")), (Some(_), None) => Err(anyhow::anyhow!("Error missing credential: payload")), diff --git a/src/rpc/dial.rs b/src/rpc/dial.rs index 171f4dc..39efadb 100644 --- a/src/rpc/dial.rs +++ b/src/rpc/dial.rs @@ -603,7 +603,7 @@ impl DialBuilder { let original_uri = Uri::from_parts(original_uri_parts)?; - let domain = original_uri.authority().clone().unwrap().to_string(); + let domain = original_uri.authority().unwrap().to_string(); let uri_for_auth = infer_remote_uri_from_authority(original_uri.clone()); let mdns_uri = mdns_uri.and_then(|p| Uri::from_parts(p).ok()); @@ -621,7 +621,7 @@ impl DialBuilder { // created with the default uri None => Err(anyhow::anyhow!("")), }; - let real_channel = match channel { + let mut real_channel = match channel { Ok(c) => { log::debug!("Connected via mDNS"); c @@ -632,26 +632,40 @@ impl DialBuilder { "Unable to connect via mDNS; falling back to robot URI. Error: {e}" ); } - Self::create_channel(allow_downgrade, &domain, uri_for_auth, false).await? + Self::create_channel(allow_downgrade, &domain, uri_for_auth.clone(), false).await? } }; log::debug!("{}", log_prefixes::ACQUIRING_AUTH_TOKEN); - let token = get_auth_token( - &mut real_channel.clone(), - self.config - .credentials - .as_ref() - .unwrap() - .credentials - .clone(), - self.config - .credentials - .unwrap() - .entity - .unwrap_or_else(|| domain.clone()), - ) - .await?; + let creds = self + .config + .credentials + .as_ref() + .unwrap() + .credentials + .clone(); + let entity = self + .config + .credentials + .unwrap() + .entity + .unwrap_or_else(|| domain.clone()); + let token = match get_auth_token(&mut real_channel.clone(), creds.clone(), entity.clone()) + .await + { + Ok(t) => t, + Err(e) => { + if !attempting_mdns { + return Err(e); + } + log::debug!( + "Error getting auth token: [{e}]. This may be the result of attempting to connect via mDNS with an incorrectly scoped API key. Attempting again without mDNS uri." + ); + real_channel = + Self::create_channel(allow_downgrade, &domain, uri_for_auth, false).await?; + get_auth_token(&mut real_channel.clone(), creds, entity).await? + } + }; log::debug!("{}", log_prefixes::ACQUIRED_AUTH_TOKEN); let channel = ServiceBuilder::new() @@ -660,7 +674,7 @@ impl DialBuilder { HeaderName::from_static("rpc-host"), HeaderValue::from_str(domain.as_str())?, )) - .service(real_channel.clone()); + .service(real_channel); if disable_webrtc { log::debug!("Connected via gRPC");