Skip to content

Commit

Permalink
Merge pull request #1174 from beehive-lab/florin/enable-thread-safe-g…
Browse files Browse the repository at this point in the history
…etAllDevices

Make {PTX/SPIRV/OCL}BackendImpl::getAllDevices thread-safe
  • Loading branch information
jjfumero authored Nov 14, 2024
2 parents 73d1dfb + e7a37bc commit 0d44252
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public final class OCLBackendImpl implements TornadoAcceleratorBackend {
private final OCLBackend[][] backends;
private final List<OCLContextInterface> contexts;
private OCLBackend[] flatBackends;
private List<TornadoDevice> devices;
private volatile List<TornadoDevice> devices;
private final TornadoLogger logger;

public OCLBackendImpl(final OptionValues options, final HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfig) {
Expand Down Expand Up @@ -167,9 +167,13 @@ public TornadoXPUDevice getDevice(int index) {
@Override
public List<TornadoDevice> getAllDevices() {
if (devices == null) {
devices = new ArrayList<>();
for (int deviceIndex = 0; deviceIndex < getNumDevices(); deviceIndex++) {
devices.add(getDevice(deviceIndex));
synchronized (this) {
if (devices == null) {
devices = new ArrayList<>();
for (int deviceIndex = 0; deviceIndex < getNumDevices(); deviceIndex++) {
devices.add(getDevice(deviceIndex));
}
}
}
}
return devices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public final class PTXBackendImpl implements TornadoAcceleratorBackend {

private final PTXBackend[] backends;
private final TornadoLogger logger;
private List<TornadoDevice> devices;
private volatile List<TornadoDevice> devices;

public PTXBackendImpl(final OptionValues options, final HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfig) {

Expand Down Expand Up @@ -132,9 +132,13 @@ public TornadoXPUDevice getDevice(int index) {
@Override
public List<TornadoDevice> getAllDevices() {
if (devices == null) {
devices = new ArrayList<>();
for (int i = 0; i < getNumDevices(); i++) {
devices.add(backends[i].getDeviceContext().toDevice());
synchronized (this) {
if (devices == null) {
devices = new ArrayList<>();
for (int i = 0; i < getNumDevices(); i++) {
devices.add(backends[i].getDeviceContext().toDevice());
}
}
}
}
return devices;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public final class SPIRVBackendImpl implements TornadoAcceleratorBackend {
* devices).
*/
private int backendCounter;
private List<TornadoDevice> devices;
private volatile List<TornadoDevice> devices;

public SPIRVBackendImpl(OptionValues options, HotSpotJVMCIRuntime vmRuntime, TornadoVMConfigAccess vmConfigAccess) {
int numPlatforms = SPIRVRuntimeImpl.getInstance().getNumPlatforms();
Expand Down Expand Up @@ -190,9 +190,13 @@ public TornadoXPUDevice getDevice(int index) {
@Override
public List<TornadoDevice> getAllDevices() {
if (devices == null) {
devices = new ArrayList<>();
for (int i = 0; i < getNumDevices(); i++) {
devices.add(flatBackends[i].getDeviceContext().toDevice());
synchronized (this) {
if (devices == null) {
devices = new ArrayList<>();
for (int i = 0; i < getNumDevices(); i++) {
devices.add(flatBackends[i].getDeviceContext().toDevice());
}
}
}
}
return devices;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright (c) 2024, APT Group, Department of Computer Science,
* The University of Manchester.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package uk.ac.manchester.tornado.unittests.functional;

import org.junit.Test;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.exceptions.TornadoExecutionPlanException;
import uk.ac.manchester.tornado.api.types.arrays.DoubleArray;

import java.util.ArrayList;
import java.util.HashSet;
import static org.junit.Assert.fail;

public class TestStreams {

private static void deviceDummyCompute(final DoubleArray src, final DoubleArray dst) {
for (@Parallel int i = 0; i < src.getSize(); i++) {
dst.set(i, src.get(i) * 2);
}
}

private static void hostComputeMethod() throws TornadoExecutionPlanException {
DoubleArray src = new DoubleArray(1024);
DoubleArray dst = new DoubleArray(1024);
String threadName = Thread.currentThread().getName();

TaskGraph taskGraph = new TaskGraph("s1") //
.transferToDevice(DataTransferMode.EVERY_EXECUTION, src, dst) //
.task(threadName, TestStreams::deviceDummyCompute, src, dst) //
.transferToHost(DataTransferMode.EVERY_EXECUTION, dst);

ImmutableTaskGraph immutableTaskGraph = taskGraph.snapshot();
try (TornadoExecutionPlan executor = new TornadoExecutionPlan(immutableTaskGraph)) {
executor.execute();
}
}

@Test
public void testParallelStreams() throws TornadoExecutionPlanException {
ArrayList<Integer> s = new ArrayList<>();
for(int i = 0; i < 512; i++) {
s.add(i);
}

s.parallelStream().forEach(k->{
try {
hostComputeMethod();
} catch (TornadoExecutionPlanException e) {
fail(STR."Got exception \{e}");
}
});
}

}

0 comments on commit 0d44252

Please sign in to comment.