Skip to content

Commit

Permalink
web-client/Trainer: inline disco
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Apr 29, 2024
1 parent 1279a9c commit 76f189f
Showing 1 changed file with 120 additions and 142 deletions.
262 changes: 120 additions & 142 deletions web-client/src/components/training/Trainer.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
<div class="space-y-4 md:space-y-8">
<!-- Train Button -->
<div class="flex justify-center">
<IconCard
title-placement="center"
class="w-3/5"
>
<template #title>
Control the Training Flow
</template>
<template
v-if="!startedTraining"
#content
>
<IconCard title-placement="center" class="w-3/5">
<template #title> Control the Training Flow </template>
<template v-if="training === undefined" #content>
<div class="grid grid-cols-2 gap-8">
<CustomButton @click="startTraining(false)">
train alone
Expand All @@ -22,14 +14,9 @@
</CustomButton>
</div>
</template>
<template
v-else
#content
>
<template v-else #content>
<div class="flex justify-center">
<CustomButton @click="pauseTraining()">
stop <span v-if="distributedTraining">collaborative training</span><span v-else>training</span>
</CustomButton>
<CustomButton @click="stopTraining()"> stop training </CustomButton>
</div>
</template>
</IconCard>
Expand All @@ -45,133 +32,124 @@
</div>
</template>

<script lang="ts">
import { List } from 'immutable'
import { defineComponent } from 'vue'
import type { RoundLogs, Task } from '@epfml/discojs-core'
import { data, isTask, Disco, client as clients } from '@epfml/discojs-core'
// TODO @s314cy: move to discojs-core/src/client/get.ts
import { getClient } from '@/clients'
import { useToaster } from '@/composables/toaster'
import TrainingInformation from '@/components/training/TrainingInformation.vue'
import CustomButton from '@/components/simple/CustomButton.vue'
import IconCard from '@/components/containers/IconCard.vue'
const toaster = useToaster()
export default defineComponent({
name: 'Trainer',
components: {
TrainingInformation,
CustomButton,
IconCard
},
props: {
task: {
validator: isTask,
default: undefined as Task | undefined
},
datasetBuilder: data.DatasetBuilder
},
data (): {
distributedTraining: boolean,
startedTraining: boolean,
logs: List<RoundLogs & { participants: number }>,
messages: List<string>,
} {
return {
distributedTraining: false,
startedTraining: false,
logs: List(),
messages: List(),
}
},
computed: {
client (): clients.Client {
return getClient(this.scheme, this.task)
},
disco (): Disco {
return new Disco(
this.task,
{
logger: {
success: (msg: string) => {
this.messages = this.messages.push(msg)
},
error: (msg: string) => {
this.messages = this.messages.push(msg)
},
},
scheme: this.scheme,
client: this.client
}
<script lang="ts" setup>
import { List } from "immutable";
import { ref, computed } from "vue";
import type { RoundLogs, Task } from "@epfml/discojs-core";
import {
aggregator as aggregators,
client as clients,
data,
EmptyMemory,
Disco,
} from "@epfml/discojs-core";
import { IndexedDB } from "@epfml/discojs";
import { CONFIG } from "@/config";
import { useMemoryStore } from "@/store/memory";
import { useToaster } from "@/composables/toaster";
import TrainingInformation from "@/components/training/TrainingInformation.vue";
import CustomButton from "@/components/simple/CustomButton.vue";
import IconCard from "@/components/containers/IconCard.vue";
const toaster = useToaster();
const memoryStore = useMemoryStore();
const props = defineProps<{
task: Task;
datasetBuilder: data.DatasetBuilder<File>;
}>();
const training =
ref<AsyncGenerator<RoundLogs & { participants: number }, void>>();
const logs = ref(List<RoundLogs & { participants: number }>());
const messages = ref(List<string>());
const hasValidationData = computed(
() => props.task.trainingInformation.validationSplit > 0,
);
async function startTraining(distributed: boolean): Promise<void> {
let dataset: data.DataSplit;
try {
dataset = await props.datasetBuilder.build({
shuffle: false,
validationSplit: props.task.trainingInformation.validationSplit,
});
} catch (e) {
console.error(e);
if (
e instanceof Error &&
e.message.includes(
"provided in columnConfigs does not match any of the column names",
)
},
scheme (): Task['trainingInformation']['scheme'] {
if (this.distributedTraining && this.task.trainingInformation?.scheme !== undefined) {
return this.task.trainingInformation?.scheme
}
) {
// missing field is specified between two "quotes"
const missingFields: String = e.message.split('"')[1].split('"')[0];
toaster.error(`The input data is missing the field "${missingFields}"`);
} else {
toaster.error(
"Incorrect data format. Please check the expected format at the previous step.",
);
}
return;
}
// default scheme
return 'local'
},
hasValidationData (): boolean {
return this.task?.trainingInformation?.validationSplit > 0
},
},
methods: {
async startTraining (distributedTraining: boolean): Promise<void> {
this.distributedTraining = distributedTraining
if (this.datasetBuilder === undefined) {
throw new Error('no dataset builder')
}
let dataset: data.DataSplit
try {
dataset = await this.datasetBuilder.build({
shuffle: true,
validationSplit: this.task.trainingInformation.validationSplit
})
} catch (e) {
console.error(e)
if (e instanceof Error && e.message.includes('provided in columnConfigs does not match any of the column names')) {
// missing field is specified between two "quotes"
const missingFields: String = e.message.split('"')[1].split('"')[0]
toaster.error(`The input data is missing the field "${missingFields}"`)
} else {
toaster.error('Incorrect data format. Please check the expected format at the previous step.')
}
this.cleanState()
return
}
toaster.info('Model training started')
try {
this.startedTraining = true
for await (const roundLogs of this.disco.fit(dataset)) {
this.logs = this.logs.push(roundLogs)
}
this.startedTraining = false
} catch (e) {
toaster.error('An error occurred during training')
console.error(e)
this.cleanState()
return
}
toaster.success('Training successfully completed')
},
cleanState (): void {
this.distributedTraining = false
this.startedTraining = false
toaster.info("Model training started");
const scheme = distributed ? props.task.trainingInformation.scheme : "local";
const client =
scheme === "local"
? new clients.Local(
CONFIG.serverUrl,
props.task,
new aggregators.MeanAggregator(),
)
: new clients.federated.FederatedClient(
CONFIG.serverUrl,
props.task,
new aggregators.MeanAggregator(),
);
const disco = new Disco(props.task, {
logger: {
success: (msg: string) => {
messages.value = messages.value.push(msg);
},
error: (msg: string) => {
messages.value = messages.value.push(msg);
},
},
async pauseTraining (): Promise<void> {
await this.disco.pause()
this.startedTraining = false
memory: memoryStore.useIndexedDB ? new IndexedDB() : new EmptyMemory(),
scheme,
client,
});
try {
training.value = disco.fit(dataset);
for await (const roundLogs of training.value)
logs.value = logs.value.push(roundLogs);
if (training.value === undefined) {
toaster.success("Training stopped");
return;
}
} catch (e) {
toaster.error("An error occurred during training");
console.error(e);
} finally {
training.value = undefined;
}
})
toaster.success("Training successfully completed");
}
async function stopTraining(): Promise<void> {
const generator = training.value;
if (generator === undefined) return;
training.value = undefined;
generator.return();
}
</script>

0 comments on commit 76f189f

Please sign in to comment.