Skip to content

Commit

Permalink
Support tag and creating experiment (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
taku-y committed Mar 30, 2024
1 parent b53642b commit e8ec329
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 21 deletions.
94 changes: 73 additions & 21 deletions border-mlflow-tracking/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run};
use anyhow::Result;
use log::info;
use reqwest::blocking::Client;
use reqwest::Request;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt::Display;
Expand Down Expand Up @@ -40,6 +41,11 @@ struct CreateRunParams {
run_name: String,
}

#[derive(Debug, Serialize)]
struct CreateExperimentParams {
name: String,
}

/// Provides access to a MLflow tracking server via REST API.
///
/// Support Mlflow API version 2.0.
Expand Down Expand Up @@ -118,45 +124,91 @@ impl MlflowTrackingClient {

/// Get [`Experiment`] by name from the tracking server.
///
/// If the experiment with given name does not exist in the trackingserver,
/// it will be created.
///
/// TODO: Better error handling
pub fn get_experiment(&self, name: impl AsRef<str>) -> Option<Experiment> {
let url = format!("{}/api/2.0/mlflow/experiments/get-by-name", self.base_url);
let resp = self
.client
let resp = match self.get(
self.url("experiments/get-by-name"),
&[("experiment_name", name.as_ref())],
) {
Ok(resp) => {
if resp.status().is_success() {
resp
} else {
// if the experiment does not exist, create it
self.post(
self.url("experiments/create"),
&CreateExperimentParams {
name: name.as_ref().into(),
},
)
.unwrap();
self.get(
self.url("experiments/get-by-name"),
&[("experiment_name", name.as_ref())],
)
.unwrap()
}
}
Err(_) => {
panic!();
}
};
let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap();

Some(experiment.experiment)
}

fn url(&self, api: impl AsRef<str>) -> String {
format!("{}/api/2.0/mlflow/{}", self.base_url, api.as_ref())
}

fn get(
&self,
url: String,
query: &impl Serialize,
) -> reqwest::Result<reqwest::blocking::Response> {
self.client
.get(url)
.basic_auth(&self.user_name, Some(&self.password))
.query(&[("experiment_name", name.as_ref())])
.query(query)
.send()
.unwrap();
let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap();
}

Some(experiment.experiment)
fn post(
&self,
url: String,
params: &impl Serialize,
) -> reqwest::Result<reqwest::blocking::Response> {
self.client
.post(url)
.basic_auth(&self.user_name, Some(&self.password))
.json(&params) // auto serialize
.send()
}

/// Create [`MlflowTrackingRecorder`] corresponding to a run.
///
/// If `name` is empty (`""`), a run name is given by the tracking server.
/// If `name` is empty (`""`), a run name is generated by the tracking server.
///
/// Need to call [`MlflowTrackingClient::set_experiment_id()`] before calling this method.
pub fn create_recorder(&self, run_name: impl AsRef<str>) -> Result<MlflowTrackingRecorder> {
let not_given_name = run_name.as_ref().len() == 0;
let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id");
let url = format!("{}/api/2.0/mlflow/runs/create", self.base_url);
let params = CreateRunParams {
experiment_id: experiment_id.to_string(),
start_time: system_time_as_millis() as i64,
run_name: run_name.as_ref().to_string(),
};
let resp = self
.client
.post(url)
.basic_auth(&self.user_name, Some(&self.password))
.json(&params) // auto serialize
.send()
.post(
self.url("runs/create"),
&CreateRunParams {
experiment_id: experiment_id.to_string(),
start_time: system_time_as_millis() as i64,
run_name: run_name.as_ref().to_string(),
},
)
.unwrap();
// println!("{:?}", resp);
// println!("{:?}", resp.text());
// TODO: Check the response from the tracking server here

let run = {
let run: Run_ =
serde_json::from_str(&resp.text().unwrap()).expect("Failed to deserialize Run");
Expand Down
26 changes: 26 additions & 0 deletions border-mlflow-tracking/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ struct UpdateRunParams<'a> {
run_name: &'a String,
}

#[derive(Debug, Serialize)]
struct SetTagParams<'a> {
run_id: &'a String,
key: &'a String,
value: &'a String,
}

#[allow(dead_code)]
/// Record metrics to the MLflow tracking server during training.
///
Expand Down Expand Up @@ -89,6 +96,7 @@ impl MlflowTrackingRecorder {
let _resp = self
.client
.post(&url)
.basic_auth(&self.user_name, Some(&self.password))
.json(&params) // auto serialize
.send()
.unwrap();
Expand All @@ -97,6 +105,24 @@ impl MlflowTrackingRecorder {

Ok(())
}

pub fn set_tag(&self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<()> {
let url = format!("{}/api/2.0/mlflow/runs/set-tag", self.base_url);
let params = SetTagParams {
run_id: &self.run_id,
key: &key.as_ref().to_string(),
value: &value.as_ref().to_string(),
};
let _resp = self
.client
.post(&url)
.basic_auth(&self.user_name, Some(&self.password))
.json(&params)
.send()
.unwrap();

Ok(())
}
}

impl Recorder for MlflowTrackingRecorder {
Expand Down

0 comments on commit e8ec329

Please sign in to comment.