Skip to content

Commit

Permalink
unified html
Browse files Browse the repository at this point in the history
  • Loading branch information
AmanPriyanshu committed Nov 24, 2024
1 parent 1222505 commit abcd264
Showing 1 changed file with 93 additions and 81 deletions.
174 changes: 93 additions & 81 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
<meta charset="UTF-8">
<title>Federated Learning Interactive Game</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- Embedded CSS -->
<style>
body {
font-family: Arial, sans-serif;
Expand All @@ -30,7 +29,7 @@
border-radius: 5px;
margin: 10px;
padding: 10px;
flex: 1 1 calc(50% - 40px); /* Adjust for margins */
flex: 1 1 calc(50% - 40px);
max-width: calc(50% - 40px);
}
@media (min-width: 768px) {
Expand Down Expand Up @@ -72,15 +71,36 @@
text-align: center;
}
.notification {
background-color: #ff9800;
padding: 10px;
margin: 10px;
border-radius: 5px;
position: fixed;
top: 20px;
left: 50%;
transform: translateX(-50%);
z-index: 1000;
background-color: #ff9800;
padding: 15px 25px;
margin: 10px;
border-radius: 5px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
animation: slideDown 0.3s ease-out;
}

@keyframes slideDown {
from {
transform: translate(-50%, -100%);
opacity: 0;
}
to {
transform: translate(-50%, 0);
opacity: 1;
}
}
</style>
</head>
<body>

<div class="notification" id="notification" style="display: none;">
</div>

<header>
<h1>Federated Learning Interactive Game</h1>
</header>
Expand All @@ -94,7 +114,6 @@ <h1>Federated Learning Interactive Game</h1>
</div>

<div class="container" id="client-container">
<!-- Client panels will be injected here -->
</div>

<div class="global-controls">
Expand All @@ -113,32 +132,24 @@ <h1>Federated Learning Interactive Game</h1>
<canvas id="global-accuracy-plot"></canvas>
</div>

<div class="notification" id="notification" style="display: none;">
<!-- Notifications will appear here -->
</div>

<footer>
<p>Created for the 30 Days of PETs Challenge</p>
<p>Created for the 30 Days of FL Challenge</p>
</footer>

<!-- Embedded JS -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<script>
// Initialize global variables
const NUM_CLIENTS = 5;
let clients = [];
let serverModel;
let globalTestData = { x: null, y: null };
let dataDistribution = 'iid';
let dropoutProbability = 0;

// Initialize the application
document.addEventListener('DOMContentLoaded', async () => {
createClientPanels();
await loadGlobalTestData();
initializeServerModel();
await reloadClientData(); // Ensure client data is loaded initially
await reloadClientData();
setupEventListeners();
initializeGlobalPlot();
});
Expand All @@ -155,14 +166,14 @@ <h3>Client ${i}</h3>
<div class="controls">
<div class="slider-label">
<label for="lr-${i}">Learning Rate:</label>
<span id="lr-value-${i}">0.01</span>
<span id="lr-value-${i}">0.00005</span>
</div>
<input type="range" id="lr-${i}" min="0.001" max="0.1" value="0.01" step="0.001">
<input type="range" id="lr-${i}" min="0.00001" max="0.001" value="0.00005" step="0.00001">
<div class="slider-label">
<label for="epochs-${i}">Epochs:</label>
<span id="epochs-value-${i}">1</span>
<span id="epochs-value-${i}">5</span>
</div>
<input type="range" id="epochs-${i}" min="1" max="10" value="1">
<input type="range" id="epochs-${i}" min="5" max="15" value="1">
</div>
<button id="train-local-${i}">Train Local</button>
<canvas id="accuracy-plot-${i}" width="200" height="150"></canvas>
Expand All @@ -176,81 +187,68 @@ <h3>Client ${i}</h3>
plot: null,
});
}
console.log("Created the client panels");
}

async function loadGlobalTestData() {
// Load global test data by combining clients' test data
globalTestData.x = [];
globalTestData.y = [];
for (let i = 1; i <= NUM_CLIENTS; i++) {
const testData = await loadClientData(i, 'test');
globalTestData.x.push(...testData.x);
globalTestData.y.push(...testData.y);
}
// Convert to tensors
globalTestData.x = tf.tensor2d(globalTestData.x, undefined, 'float32').div(tf.scalar(255));
globalTestData.y = tf.tensor1d(globalTestData.y, 'int32');
globalTestData.y = tf.tensor1d(globalTestData.y, 'float32');
console.log("Loaded globalTestData");
}

function initializeServerModel() {
// Create a simple MLP model
serverModel = createModel();
}

function setupEventListeners() {
// Data selection change
document.getElementById('data-selection').addEventListener('change', async (e) => {
dataDistribution = e.target.value;
await reloadClientData();
});

// Dropout probability slider
document.getElementById('dropout-prob').addEventListener('input', (e) => {
dropoutProbability = parseInt(e.target.value);
document.getElementById('dropout-value').innerText = `${dropoutProbability}%`;
});

// Train all clients
document.getElementById('run-all').addEventListener('click', () => {
clients.forEach((client) => trainLocalModel(client.id));
});

// Aggregate models
document.getElementById('aggregate').addEventListener('click', aggregateModels);

// Client-specific event listeners
clients.forEach((client) => {
// Learning rate slider
document.getElementById(`lr-${client.id}`).addEventListener('input', (e) => {
const lr = parseFloat(e.target.value);
document.getElementById(`lr-value-${client.id}`).innerText = lr.toFixed(3);
});

// Epochs slider
document.getElementById(`epochs-${client.id}`).addEventListener('input', (e) => {
const epochs = parseInt(e.target.value);
document.getElementById(`epochs-value-${client.id}`).innerText = epochs;
});

// Train local model
document.getElementById(`train-local-${client.id}`).addEventListener('click', () => {
trainLocalModel(client.id);
});
});
}

async function reloadClientData() {
// Reload data for each client
for (let client of clients) {
client.data = await loadClientData(client.id, 'train');
}
}

async function loadClientData(clientId, dataType) {
const response = await fetch(`data/${dataDistribution}/client${clientId}.json`);
const response = await fetch(`https://amanpriyanshu.github.io/FL-JS-MNIST/data/${dataDistribution}/client${clientId}.json`);
const data = await response.json();
const x = dataType === 'train' ? data.x_train : data.x_test;
const y = dataType === 'train' ? data.y_train : data.y_test;
console.log("Loaded client data", clientId);
return {
x: x,
y: y,
Expand All @@ -266,64 +264,62 @@ <h3>Client ${i}</h3>
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy'],
});
console.log("Created client model");
return model;
}

async function trainLocalModel(clientId) {
const client = clients.find((c) => c.id === clientId);

// Simulate communication dropout
if (Math.random() < dropoutProbability / 100) {
showNotification(`Client ${clientId} dropped out during communication.`);
return;
}

// Load data if not already loaded
if (!client.data) {
client.data = await loadClientData(clientId, 'train');
}

// Get training parameters
const lr = parseFloat(document.getElementById(`lr-${clientId}`).value);
const epochs = parseInt(document.getElementById(`epochs-${clientId}`).value);

// Create a new model and set its weights to the server model's weights
client.model = createModel();
client.model.setWeights(serverModel.getWeights());
console.log("set weights to server model");

// Compile the model with the client's learning rate
client.model.compile({
optimizer: tf.train.adam(lr),
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy'],
});

// Convert data to tensors
const xTrain = tf.tensor2d(client.data.x, undefined, 'float32').div(tf.scalar(255));
const yTrain = tf.tensor1d(client.data.y, 'int32');
const xTrain = tf.tensor2d(client.data.x, [1000, 784], 'float32').div(tf.scalar(255.0));
const yTrain = tf.tensor1d(client.data.y, 'float32');
console.log("Starting training with xTrain dtype and shape:", xTrain.dtype, xTrain.shape, "and yTrain dtype:", yTrain.dtype, yTrain.shape);
console.log("epochs and lr", epochs, lr)

// Train the model
await client.model.fit(xTrain, yTrain, {
epochs: epochs,
verbose: 0,
callbacks: {
onEpochEnd: async (epoch, logs) => {
// Update local accuracy plot
const acc = logs.acc;
console.log(`Epoch ${epoch + 1} accuracy: ${(acc * 100).toFixed(2)}%`);
client.accuracyHistory.push(acc);
updateClientPlot(clientId);
},
},
});
});

console.log("trained model");

// Clean up tensors
xTrain.dispose();
yTrain.dispose();

// Evaluate on the local test data
const testData = await loadClientData(clientId, 'test');
const xTest = tf.tensor2d(testData.x, undefined, 'float32').div(tf.scalar(255));
const yTest = tf.tensor1d(testData.y, 'int32');
const yTest = tf.tensor1d(testData.y, 'float32');
console.log("Loaded xTest and yTest");
const evalResult = client.model.evaluate(xTest, yTest, { verbose: 0 });
const testAcc = (await evalResult[1].data())[0];
showNotification(`Client ${clientId} Test Accuracy: ${(testAcc * 100).toFixed(2)}%`);
Expand All @@ -338,13 +334,11 @@ <h3>Client ${i}</h3>
return model;
}

// Update the cloneModel function
function cloneModel(model) {
return model.clone();
}

async function aggregateModels() {
// Collect models from clients
const clientModels = clients
.filter((client) => client.model)
.map((client) => client.model);
Expand All @@ -354,11 +348,14 @@ <h3>Client ${i}</h3>
return;
}

// Average the weights
const averagedWeights = averageWeights(clientModels);
serverModel.setWeights(averagedWeights);
serverModel.compile({
optimizer: tf.train.adam(0.001),
loss: 'sparseCategoricalCrossentropy',
metrics: ['accuracy'],
});

// Evaluate on global test data
const evalResult = serverModel.evaluate(globalTestData.x, globalTestData.y, { verbose: 0 });
const globalAcc = (await evalResult[1].data())[0];
updateGlobalPlot(globalAcc);
Expand Down Expand Up @@ -412,32 +409,47 @@ <h3>Client ${i}</h3>
function updateClientPlot(clientId) {
const client = clients.find((c) => c.id === clientId);
if (!client.plot) {
const ctx = document.getElementById(`accuracy-plot-${clientId}`).getContext('2d');
client.plot = new Chart(ctx, {
type: 'line',
data: {
labels: client.accuracyHistory.map((_, i) => i + 1),
datasets: [{
label: `Client ${clientId} Accuracy`,
data: client.accuracyHistory,
borderColor: '#2196f3',
fill: false,
}],
},
options: {
responsive: true,
scales: {
x: { display: true, title: { display: true, text: 'Epochs' } },
y: { display: true, title: { display: true, text: 'Accuracy' }, min: 0, max: 1 },
},
},
});
const ctx = document.getElementById(`accuracy-plot-${clientId}`).getContext('2d');
client.plot = new Chart(ctx, {
type: 'line',
data: {
labels: Array.from({length: client.accuracyHistory.length}, (_, i) => i + 1),
datasets: [{
label: `Client ${clientId} Accuracy`,
data: client.accuracyHistory,
borderColor: '#2196f3',
fill: false,
}],
},
options: {
responsive: true,
scales: {
x: {
type: 'linear',
display: true,
title: { display: true, text: 'Epochs' },
ticks: {
stepSize: 1,
callback: function(value) {
return Math.floor(value);
}
}
},
y: {
display: true,
title: { display: true, text: 'Accuracy' },
min: 0,
max: 1
},
},
},
});
} else {
client.plot.data.labels.push(client.accuracyHistory.length);
client.plot.data.datasets[0].data.push(client.accuracyHistory[client.accuracyHistory.length - 1]);
client.plot.update();
client.plot.data.labels = Array.from({length: client.accuracyHistory.length}, (_, i) => i + 1);
client.plot.data.datasets[0].data = client.accuracyHistory;
client.plot.update();
}
}
}

function showNotification(message) {
const notification = document.getElementById('notification');
Expand Down

0 comments on commit abcd264

Please sign in to comment.