feat: rename base_url env variable to inference_url
moldhouse committed Dec 10, 2024
1 parent a97e02a commit b9a2fd2
Showing 4 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .env.example
@@ -1,2 +1,2 @@
PHARIA_AI_TOKEN=
-AA_BASE_URL=
+INFERENCE_URL=
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -28,5 +28,5 @@ jobs:
- name: Run tests
env:
PHARIA_AI_TOKEN: ${{ secrets.PHARIA_AI_TOKEN }}
-AA_BASE_URL: https://inference-api.product.pharia.com
+INFERENCE_URL: https://inference-api.product.pharia.com
run: cargo test
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -96,8 +96,8 @@ impl Client {
pub fn from_env() -> Result<Self, Error> {
let _ = dotenv();
let api_token = env::var("PHARIA_AI_TOKEN").unwrap();
-let base_url = env::var("AA_BASE_URL").unwrap();
-Self::with_base_url(base_url, api_token)
+let inference_url = env::var("INFERENCE_URL").unwrap();
+Self::with_base_url(inference_url, api_token)
}

/// Execute a task with the aleph alpha API and fetch its result.
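The user-visible effect of this change is that Client::from_env now reads INFERENCE_URL (instead of AA_BASE_URL) alongside PHARIA_AI_TOKEN. A minimal usage sketch, not part of this commit, assuming the crate is imported as aleph_alpha_client and a tokio runtime with the macros feature is available:

use aleph_alpha_client::{Client, How, TaskCompletion};

#[tokio::main]
async fn main() {
    // from_env reads PHARIA_AI_TOKEN and INFERENCE_URL, optionally from a
    // .env file loaded by dotenv, as shown in the src/lib.rs hunk above.
    let client = Client::from_env().unwrap();

    // Same completion task the integration tests below use.
    let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);
    let _output = client
        .output_of(&task.with_model("luminous-base"), &How::default())
        .await
        .unwrap();
}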
52 changes: 26 additions & 26 deletions tests/integration.rs
@@ -20,13 +20,13 @@ fn pharia_ai_token() -> &'static str {
&PHARIA_AI_TOKEN
}

-fn base_url() -> &'static str {
-static AA_BASE_URL: LazyLock<String> = LazyLock::new(|| {
+fn inference_url() -> &'static str {
+static INFERENCE_URL: LazyLock<String> = LazyLock::new(|| {
drop(dotenv());
std::env::var("AA_BASE_URL")
.expect("AA_BASE_URL environment variable must be specified to run tests.")
std::env::var("INFERENCE_URL")
.expect("INFERENCE_URL environment variable must be specified to run tests.")
});
-&AA_BASE_URL
+&INFERENCE_URL
}

#[tokio::test]
@@ -36,7 +36,7 @@ async fn chat_with_pharia_1_7b_base() {
let task = TaskChat::with_message(message);

let model = "pharia-1-llm-7b-control";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client.chat(&task, model, &How::default()).await.unwrap();

// Then
@@ -49,7 +49,7 @@ async fn completion_with_luminous_base() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -67,7 +67,7 @@ async fn request_authentication_has_priority() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

let model = "luminous-base";
-let client = Client::with_base_url(base_url(), bad_pharia_ai_token).unwrap();
+let client = Client::with_base_url(inference_url(), bad_pharia_ai_token).unwrap();
let response = client
.output_of(
&task.with_model(model),
@@ -92,7 +92,7 @@ async fn authentication_only_per_request() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

// When
-let client = Client::new(base_url().to_owned(), None).unwrap();
+let client = Client::new(inference_url().to_owned(), None).unwrap();
let response = client
.output_of(
&task.with_model(model),
@@ -116,7 +116,7 @@ async fn must_panic_if_authentication_is_missing() {
let task = TaskCompletion::from_text("Hello").with_maximum_tokens(1);

// When
-let client = Client::new(base_url().to_owned(), None).unwrap();
+let client = Client::new(inference_url().to_owned(), None).unwrap();
client
.output_of(&task.with_model(model), &How::default())
.await
@@ -140,7 +140,7 @@ async fn semanitc_search_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);
let query = Prompt::from_text("What is Pizza?");
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let robot_embedding_task = TaskSemanticEmbedding {
@@ -203,7 +203,7 @@ async fn complete_structured_prompt() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -232,7 +232,7 @@ async fn maximum_tokens_none_request() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
Expand All @@ -253,7 +253,7 @@ async fn explain_request() {
target: " How is it going?",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Sentence),
};
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -283,7 +283,7 @@ async fn explain_request_with_auto_granularity() {
target: " How is it going?",
granularity: Granularity::default(),
};
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -315,7 +315,7 @@ async fn explain_request_with_image_modality() {
target: " a cat.",
granularity: Granularity::default().with_prompt_granularity(PromptGranularity::Paragraph),
};
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let response = client
@@ -365,7 +365,7 @@ async fn describe_image_starting_from_a_path() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -394,7 +394,7 @@ async fn describe_image_starting_from_a_dyn_image() {
sampling: Sampling::MOST_LIKELY,
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -420,7 +420,7 @@ async fn only_answer_with_specific_animal() {
},
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -447,7 +447,7 @@ async fn answer_should_continue() {
},
};
let model = "luminous-base";
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
@@ -474,7 +474,7 @@ async fn batch_semanitc_embed_with_luminous_base() {
temperature, traditionally in a wood-fired oven.",
);

-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let embedding_task = TaskBatchSemanticEmbedding {
@@ -499,7 +499,7 @@ async fn tokenization_with_luminous_base() {
// Given
let input = "Hello, World!";

-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let task1 = TaskTokenization::new(input, false, true);
@@ -536,7 +536,7 @@ async fn detokenization_with_luminous_base() {
// Given
let input = vec![49222, 15, 5390, 4];

-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let task = TaskDetokenization { token_ids: &input };
@@ -553,7 +553,7 @@ async fn detokenization_with_luminous_base() {
#[tokio::test]
async fn fetch_tokenizer_for_pharia_1_llm_7b() {
// Given
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();

// When
let tokenizer = client
@@ -568,7 +568,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
#[tokio::test]
async fn stream_completion() {
// Given a streaming completion task
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let task = TaskCompletion::from_text("").with_maximum_tokens(7);

// When the events are streamed and collected
@@ -601,7 +601,7 @@ async fn stream_completion() {
#[tokio::test]
async fn stream_chat_with_pharia_1_llm_7b() {
// Given a streaming completion task
-let client = Client::with_base_url(base_url(), pharia_ai_token()).unwrap();
+let client = Client::with_base_url(inference_url(), pharia_ai_token()).unwrap();
let message = Message::user("Hello,");
let task = TaskChat::with_messages(vec![message]).with_maximum_tokens(7);
