Merge pull request #10 from opensass/get-mod
feat: impl `get-model` endpoint
wiseaidev authored Nov 18, 2024
2 parents 8ef7804 + c6f4a72 commit c60d580
Showing 4 changed files with 91 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/get_mod.rs
@@ -0,0 +1,55 @@
//! Reference: https://docs.x.ai/api/endpoints#get-model
use crate::error::check_for_model_error;
use crate::error::XaiError;
use crate::traits::{ClientConfig, ModelInfoFetcher};
use reqwest::Method;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfoResponse {
    pub created: u64,
    pub id: String,
    pub object: String,
    pub owned_by: String,
}

#[derive(Debug, Clone)]
pub struct ModelRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
    client: T,
    model_id: String,
}

impl<T> ModelRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    pub fn new(client: T, model_id: String) -> Self {
        Self { client, model_id }
    }
}

impl<T> ModelInfoFetcher for ModelRequestBuilder<T>
where
    T: ClientConfig + Clone + Send + Sync,
{
    async fn fetch_model_info(&self) -> Result<ModelInfoResponse, XaiError> {
        let url = format!("models/{}", self.model_id);

        let response = self.client.request(Method::GET, &url)?.send().await?;

        if response.status().is_success() {
            let body = response.text().await?;
            let model_info = serde_json::from_str::<ModelInfoResponse>(&body)?;
            Ok(model_info)
        } else {
            let error_body = response.text().await.unwrap_or_else(|_| "".to_string());

            if let Some(model_error) = check_for_model_error(&error_body) {
                return Err(model_error);
            }

            Err(XaiError::Http(error_body))
        }
    }
}
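
For context, a minimal sketch (not part of this diff) of deserializing a get-model response body into `ModelInfoResponse`. The field values mirror the assertions in tests/get_mod.rs below; the exact wire format is otherwise an assumption based on the struct fields and the get-model reference linked above.

// Sketch only: parses an assumed get-model payload into `ModelInfoResponse`.
use x_ai::get_mod::ModelInfoResponse;

fn parse_example() -> Result<ModelInfoResponse, serde_json::Error> {
    // Hypothetical body; values taken from the test assertions below.
    let body = r#"{"created":1727136000,"id":"grok-beta","object":"model","owned_by":"xai"}"#;
    serde_json::from_str::<ModelInfoResponse>(body)
}
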
1 change: 1 addition & 0 deletions src/lib.rs
@@ -6,6 +6,7 @@ pub mod embedding;
pub mod embedding_get;
pub mod embedding_mod;
pub mod error;
pub mod get_mod;
pub mod lang_mod;
pub mod list_lang_mod;
pub mod list_mod;
5 changes: 5 additions & 0 deletions src/traits.rs
@@ -10,6 +10,7 @@ use crate::embedding::EmbeddingResponse;
use crate::embedding_get::EmbeddingModelResponse;
use crate::embedding_mod::EmbeddingModelsResponse;
use crate::error::XaiError;
use crate::get_mod::ModelInfoResponse;
use crate::lang_mod::LanguageModelDetailResponse;
use crate::list_lang_mod::LanguageModelListResponse;
use crate::list_mod::ReducedModelListResponse;
@@ -65,3 +66,7 @@ pub trait GetModelFetcher {
pub trait ListModelFetcher {
    async fn fetch_model_info(&self) -> Result<ReducedModelListResponse, XaiError>;
}

pub trait ModelInfoFetcher {
    async fn fetch_model_info(&self) -> Result<ModelInfoResponse, XaiError>;
}
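
As an illustration only (not part of this commit), a minimal stub implementation of `ModelInfoFetcher` that could back offline tests. `StubFetcher` is a hypothetical type, the field values reuse those asserted in tests/get_mod.rs, and this assumes the crate's native async-fn-in-trait style shown above (Rust 1.75+).

// Hypothetical stub fetcher for tests that should not call the live API.
use x_ai::error::XaiError;
use x_ai::get_mod::ModelInfoResponse;
use x_ai::traits::ModelInfoFetcher;

struct StubFetcher;

impl ModelInfoFetcher for StubFetcher {
    async fn fetch_model_info(&self) -> Result<ModelInfoResponse, XaiError> {
        // Returns the same values the live test asserts for "grok-beta".
        Ok(ModelInfoResponse {
            created: 1727136000,
            id: "grok-beta".to_string(),
            object: "model".to_string(),
            owned_by: "xai".to_string(),
        })
    }
}
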
30 changes: 30 additions & 0 deletions tests/get_mod.rs
@@ -0,0 +1,30 @@
use std::env;
use x_ai::client::XaiClient;
use x_ai::get_mod::ModelRequestBuilder;
use x_ai::traits::{ClientConfig, ModelInfoFetcher};

#[tokio::test]
async fn test_fetch_model_info() {
    let client = XaiClient::builder()
        .build()
        .expect("Failed to build XaiClient");

    client.set_api_key(
        env::var("XAI_API_KEY")
            .expect("XAI_API_KEY must be set!")
            .to_string(),
    );

    let model_id = "grok-beta".to_string();
    let request_builder = ModelRequestBuilder::new(client, model_id);

    let result = request_builder.fetch_model_info().await;

    assert!(result.is_ok());

    let model_info = result.unwrap();
    assert_eq!(model_info.id, "grok-beta");
    assert_eq!(model_info.object, "model");
    assert_eq!(model_info.owned_by, "xai");
    assert_eq!(model_info.created, 1727136000);
}
