Skip to content

Commit

Permalink
Ugh I am likely going to sin, because I need to have Docker have my k…
Browse files Browse the repository at this point in the history
…eystore up...so I just wonder how to do this...
  • Loading branch information
AndrewQuijano committed Oct 23, 2023
1 parent fa726f2 commit 24b4ec8
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 58 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/build-gradle-project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ jobs:
ALIAS: andrew
KEYSTORE: andrew_keystore
PASSWORD: ${{ secrets.PASSWORD }}
CERTIFICATE: andrew_certificate

steps:
- name: Checkout project sources
Expand All @@ -26,7 +25,11 @@ jobs:
java-version: '17'
cache: 'gradle'

- run: sh gradlew build
- name: Create keystore for local use
run: sh create_keystore.sh

- name: Run Gradle Testing
run: sh gradlew build

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions src/main/java/weka/finito/client.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import static weka.finito.utils.shared.*;

public final class client implements Runnable {
private final SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();

private final String classes_file = "classes.txt";
private final String features_file;
private final int key_size;
Expand Down Expand Up @@ -204,7 +206,7 @@ private boolean need_keys() {
}

private String [] read_classes() {
// Don't forget to remember the classes of DT as well
// Remember the classes of DT as well
StringBuilder content = new StringBuilder();
String line;

Expand Down Expand Up @@ -397,8 +399,6 @@ else if (comparison_type == 1) {
public void run() {

// Step: 1
SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();

boolean talk_to_server_site = this.need_keys();

try {
Expand Down
11 changes: 4 additions & 7 deletions src/main/java/weka/finito/level_site_thread.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
import java.io.ObjectOutputStream;
import java.net.Socket;
import java.util.Hashtable;
import java.util.List;
import java.util.Map;

import static weka.finito.utils.shared.compare;
import static weka.finito.utils.shared.traverse_level;

public class level_site_thread implements Runnable {
Expand All @@ -32,8 +30,6 @@ public class level_site_thread implements Runnable {
public level_site_thread(Socket client_socket, level_order_site level_site_data, AES crypto) {
this.client_socket = client_socket;
this.crypto = crypto;
String clientIpAddress = client_socket.getInetAddress().getHostAddress();
System.out.println("Level-Site got connection from client: " + clientIpAddress);

Object x;
try {
Expand All @@ -42,7 +38,7 @@ public level_site_thread(Socket client_socket, level_order_site level_site_data,

x = fromClient.readObject();
if (x instanceof level_order_site) {
// Traffic from Server. Level-Site alone will manage closing this.
// Traffic from Server. Level-Site alone will manage to close this.
this.level_site_data = (level_order_site) x;
// System.out.println("Level-Site received training data on Port: " + client_socket.getLocalPort());
this.toClient.writeBoolean(true);
Expand Down Expand Up @@ -115,7 +111,7 @@ public final void run() {
}

// Null, keep going down the tree,
// Not-null, you got the correct leaf node of your DT!
// Not null, you got the correct leaf node of your DT!
NodeInfo reply = traverse_level(level_site_data, encrypted_features, toClient, niu);

String encrypted_next_index;
Expand All @@ -130,8 +126,9 @@ public final void run() {
toClient.writeObject(reply.getVariableName());
}
else {

toClient.writeBoolean(false);
// encrypt with AES, send to client which will send to next level-site
// encrypt with AES, send to the client which will send to next level-site
encrypted_next_index = crypto.encrypt(String.valueOf(this.level_site_data.get_next_index()));
iv = crypto.getIV();
toClient.writeObject(encrypted_next_index);
Expand Down
83 changes: 37 additions & 46 deletions src/main/java/weka/finito/server.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
public final class server implements Runnable {

private final SSLServerSocketFactory factory = (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
private final SSLSocketFactory socket_factory = (SSLSocketFactory) SSLSocketFactory.getDefault();
private static final String os = System.getProperty("os.name").toLowerCase();
private final String training_data;
private final String [] level_site_ips;
Expand All @@ -49,14 +50,13 @@ public final class server implements Runnable {
private PaillierPublicKey paillier_public;
private DGKPublicKey dgk_public;
private final int precision;
private ClassifierTree ppdt;
private ClassifierTree ppdt = null;
private final List<String> leaves = new ArrayList<>();
private final List<level_order_site> all_level_sites = new ArrayList<>();

private final int server_port;
private final boolean use_level_sites;

private String client;
private int evaluations = 1;

public static void main(String[] args) {
int port = 0;
Expand Down Expand Up @@ -99,9 +99,6 @@ else if (args.length == 2){
}
String[] level_domains = level_domains_str.split(",");

// Figure out Client IP for last level-site to send back
String client = System.getenv("CLIENT");

// Create and run the server.
System.out.println("Server Initialized and started running");

Expand All @@ -111,7 +108,7 @@ else if (args.length == 2){
server = new server(training_data, level_domains, port, precision, port);
}
else {
server = new server(training_data, precision, port, client);
server = new server(training_data, precision, port);
}
server.run();
}
Expand All @@ -124,50 +121,44 @@ public server(String training_data, String [] level_site_ips, int [] level_site_
this.level_site_ports = level_site_ports;
this.precision = precision;
this.server_port = server_port;
this.use_level_sites = true;
}

// For Cloud environment, (Testing with Kubernetes, for level-sites)
// For Cloud environment, (Testing with Kubernetes/EKS, for level-sites)
public server(String training_data, String [] level_site_domains, int port, int precision, int server_port) {
this.training_data = training_data;
this.level_site_ips = level_site_domains;
this.port = port;
this.precision = precision;
this.server_port = server_port;
this.use_level_sites = true;
// I will likely want more than 1 test with server-site!
this.evaluations = 100;
}

// For testing, but having just a client and server
// For testing, but having just a client and server, just do one evaluation for the sake of testing.
public server(String training_data, int precision, int server_port) {
this.training_data = training_data;
this.level_site_ips = null;
this.precision = precision;
this.server_port = server_port;
this.use_level_sites = false;
}

public server(String training_data, int precision, int server_port, String client) {
this.training_data = training_data;
this.level_site_ips = null;
this.precision = precision;
this.server_port = server_port;
this.use_level_sites = false;
this.client = client;
}

private void run_one_time(int port) throws IOException, HomomorphicException, ClassNotFoundException {
try (SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port);) {
private void run_server_site(int port) throws IOException, HomomorphicException, ClassNotFoundException {
int count = 0;
try (SSLServerSocket serverSocket = (SSLServerSocket) factory.createServerSocket(port)) {
serverSocket.setEnabledProtocols(protocols);
serverSocket.setEnabledCipherSuites(cipher_suites);

System.out.println("Server will be waiting for direct evaluation from client");
try (SSLSocket client_site = (SSLSocket) serverSocket.accept()) {
evaluate(client_site);
while (count < evaluations) {
try (SSLSocket client_site = (SSLSocket) serverSocket.accept()) {
evaluate_with_client_directly(client_site);
}
++count;
}
}
}

private void evaluate(SSLSocket client_site)
private void evaluate_with_client_directly(SSLSocket client_site)
throws IOException, HomomorphicException, ClassNotFoundException {

ObjectOutputStream to_client_site = new ObjectOutputStream(client_site.getOutputStream());
Expand Down Expand Up @@ -468,37 +459,39 @@ else if (node_info.comparisonType == 6) {
// Evaluate, be prepared for either level-site or no level-site case, 1 time
public void run() {

// Step 1
SSLSocketFactory factory = (SSLSocketFactory) SSLSocketFactory.getDefault();

try {
// Train the DT
ppdt = train_decision_tree(this.training_data);
// Get Public Keys from Client AND train level-sites
client_communication();
// Train the DT if you have to.
if (ppdt == null) {
ppdt = train_decision_tree(this.training_data);
// Get Public Keys from Client AND train level-sites
client_communication();
}
}
catch (Exception e) {
throw new RuntimeException(e);
}

ObjectOutputStream to_level_site;
ObjectInputStream from_level_site;
int connection_port;

// If we are testing without level-sites do this...
if (!use_level_sites) {
if (this.level_site_ips != null) {
train_level_sites();
} else {
try {
run_one_time(this.server_port);
run_server_site(this.server_port);
}
catch (IOException | HomomorphicException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
return;
}
}

private void train_level_sites() {
ObjectOutputStream to_level_site;
ObjectInputStream from_level_site;
int connection_port;

// There should be at least 1 IP Address for each level site
assert this.level_site_ips != null;
if(this.level_site_ips.length < all_level_sites.size()) {
assert this.level_site_ips != null;
if(this.level_site_ips.length < all_level_sites.size()) {
String error = String.format("Please create more level-sites for the " +
"decision tree trained from %s", training_data);
throw new RuntimeException(error);
Expand All @@ -518,12 +511,10 @@ public void run() {
if (i + 1 != all_level_sites.size()) {
current_level_site.set_next_level_site(level_site_ips[(i + 1) % level_site_ips.length]);
}
else {
current_level_site.set_next_level_site(client);
}

current_level_site.set_next_level_site_port(connection_port);

try(SSLSocket level_site = (SSLSocket) factory.createSocket(level_site_ips[i], connection_port)) {
try(SSLSocket level_site = (SSLSocket) socket_factory.createSocket(level_site_ips[i], connection_port)) {
// Step: 3
level_site.setEnabledProtocols(protocols);
level_site.setEnabledCipherSuites(cipher_suites);
Expand Down

0 comments on commit 24b4ec8

Please sign in to comment.