-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from opensass/get-mod
feat: impl `get-model` endpoint
- Loading branch information
Showing
4 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
//! Reference: https://docs.x.ai/api/endpoints#get-model | ||
use crate::error::check_for_model_error; | ||
use crate::error::XaiError; | ||
use crate::traits::{ClientConfig, ModelInfoFetcher}; | ||
use reqwest::Method; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct ModelInfoResponse { | ||
pub created: u64, | ||
pub id: String, | ||
pub object: String, | ||
pub owned_by: String, | ||
} | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct ModelRequestBuilder<T: ClientConfig + Clone + Send + Sync> { | ||
client: T, | ||
model_id: String, | ||
} | ||
|
||
impl<T> ModelRequestBuilder<T> | ||
where | ||
T: ClientConfig + Clone + Send + Sync, | ||
{ | ||
pub fn new(client: T, model_id: String) -> Self { | ||
Self { client, model_id } | ||
} | ||
} | ||
|
||
impl<T> ModelInfoFetcher for ModelRequestBuilder<T> | ||
where | ||
T: ClientConfig + Clone + Send + Sync, | ||
{ | ||
async fn fetch_model_info(&self) -> Result<ModelInfoResponse, XaiError> { | ||
let url = format!("models/{}", self.model_id); | ||
|
||
let response = self.client.request(Method::GET, &url)?.send().await?; | ||
|
||
if response.status().is_success() { | ||
let body = response.text().await?; | ||
let chat_completion = serde_json::from_str::<ModelInfoResponse>(&body)?; | ||
Ok(chat_completion) | ||
} else { | ||
let error_body = response.text().await.unwrap_or_else(|_| "".to_string()); | ||
|
||
if let Some(model_error) = check_for_model_error(&error_body) { | ||
return Err(model_error); | ||
} | ||
|
||
Err(XaiError::Http(error_body)) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
use std::env; | ||
use x_ai::client::XaiClient; | ||
use x_ai::get_mod::ModelRequestBuilder; | ||
use x_ai::traits::{ClientConfig, ModelInfoFetcher}; | ||
|
||
#[tokio::test] | ||
async fn test_fetch_model_info() { | ||
let client = XaiClient::builder() | ||
.build() | ||
.expect("Failed to build XaiClient"); | ||
|
||
client.set_api_key( | ||
env::var("XAI_API_KEY") | ||
.expect("XAI_API_KEY must be set!") | ||
.to_string(), | ||
); | ||
|
||
let model_id = "grok-beta".to_string(); | ||
let request_builder = ModelRequestBuilder::new(client, model_id); | ||
|
||
let result = request_builder.fetch_model_info().await; | ||
|
||
assert!(result.is_ok()); | ||
|
||
let model_info = result.unwrap(); | ||
assert_eq!(model_info.id, "grok-beta"); | ||
assert_eq!(model_info.object, "model"); | ||
assert_eq!(model_info.owned_by, "xai"); | ||
assert_eq!(model_info.created, 1727136000); | ||
} |