Skip to content

Commit

Permalink
Merge branch 'master' into naman-metrics-unit-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
namannandan authored Oct 10, 2023
2 parents cb1942c + 4d6dbe6 commit 5804690
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 36 deletions.
1 change: 1 addition & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ the backend workers convert "Bytearray to utf-8 string" when the Content-Type of
* `limit_max_image_pixels` : Default value is "true" (uses the default [PIL.Image.MAX_IMAGE_PIXELS](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS)). If set to "false", `PIL.Image.MAX_IMAGE_PIXELS` is set to `None` in the backend default vision handler to allow large image payloads.
* `allowed_urls` : Comma separated regex of allowed source URL(s) from where models can be registered. Default: `file://.*|http(s)?://.*` (all URLs and local file system)
e.g. : To allow base URLs `https://s3.amazonaws.com/` and `https://torchserve.pytorch.org/` use the following regex string `allowed_urls=https://s3.amazonaws.com/.*,https://torchserve.pytorch.org/.*`
* For security reasons, `use_env_allowed_urls=true` must be set in config.properties for `allowed_urls` to be read from the environment variable.
* `workflow_store` : Path of workflow store directory. Defaults to model store directory.
* `disable_system_metrics` : Disable collection of system metrics when set to "true". Default value is "false".

Expand Down
3 changes: 2 additions & 1 deletion examples/large_models/tp_llama/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def forward(
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
# Call PyTorch SDPA so that Flash Attention 2 and xFormers memory-efficient kernels can be used.
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq.transpose(1,2), keys.transpose(1,2), values.transpose(1,2), attn_mask=mask, dropout_p=0.0, is_causal=False)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.file.FileAlreadyExistsException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.pytorch.serve.archive.utils.ArchiveUtils;
import org.pytorch.serve.archive.utils.InvalidArchiveURLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -19,20 +23,42 @@ public final class HttpUtils {
private HttpUtils() {}

/** Copy model from S3 url to local model store */
public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3SseKmsEnabled)
throws IOException {
// for a simple GET, we have no body so supply the precomputed 'empty' hash
Map<String, String> headers;
if (s3SseKmsEnabled) {
String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
String regionName = System.getenv("AWS_DEFAULT_REGION");
if (!regionName.isEmpty() && !awsAccessKey.isEmpty() && !awsSecretKey.isEmpty()) {
public static boolean copyURLToFile(
List<String> allowedUrls,
String url,
File modelLocation,
boolean s3SseKmsEnabled,
String archiveName)
throws FileAlreadyExistsException, IOException, InvalidArchiveURLException {
if (ArchiveUtils.validateURL(allowedUrls, url)) {
if (modelLocation.exists()) {
throw new FileAlreadyExistsException(archiveName);
}

if (archiveName.contains("/") || archiveName.contains("\\")) {
throw new IOException(
"Security alert slash or backslash appear in archiveName:" + archiveName);
}

// for a simple GET, we have no body so supply the precomputed 'empty' hash
Map<String, String> headers;
if (s3SseKmsEnabled) {
String awsAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
String awsSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
String regionName = System.getenv("AWS_DEFAULT_REGION");
if (regionName.isEmpty() || awsAccessKey.isEmpty() || awsSecretKey.isEmpty()) {
throw new IOException(
"Miss environment variables "
+ "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION");
}

HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
headers = new HashMap<>();
headers.put("x-amz-content-sha256", AWS4SignerBase.EMPTY_BODY_SHA256);

AWS4SignerForAuthorizationHeader signer =
new AWS4SignerForAuthorizationHeader(endpointUrl, "GET", "s3", regionName);
new AWS4SignerForAuthorizationHeader(
connection.getURL(), "GET", "s3", regionName);
String authorization =
signer.computeSignature(
headers,
Expand All @@ -44,7 +70,7 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3
// place the computed signature into a formatted 'Authorization' header
// and call S3
headers.put("Authorization", authorization);
HttpURLConnection connection = createHttpConnection(endpointUrl, "GET", headers);
setHttpConnection(connection, "GET", headers);
try {
FileUtils.copyInputStreamToFile(connection.getInputStream(), modelLocation);
} finally {
Expand All @@ -53,28 +79,23 @@ public static void copyURLToFile(URL endpointUrl, File modelLocation, boolean s3
}
}
} else {
throw new IOException(
"Miss environment variables "
+ "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY or AWS_DEFAULT_REGION");
URL endpointUrl = new URL(url);
FileUtils.copyURLToFile(endpointUrl, modelLocation);
}
} else {
FileUtils.copyURLToFile(endpointUrl, modelLocation);
}
return false;
}

public static HttpURLConnection createHttpConnection(
URL endpointUrl, String httpMethod, Map<String, String> headers) throws IOException {

HttpURLConnection connection = (HttpURLConnection) endpointUrl.openConnection();
public static void setHttpConnection(
HttpURLConnection connection, String httpMethod, Map<String, String> headers)
throws IOException {
connection.setRequestMethod(httpMethod);

if (headers != null) {
for (String headerKey : headers.keySet()) {
connection.setRequestProperty(headerKey, headers.get(headerKey));
}
}

return connection;
}

public static String urlEncode(String url, boolean keepPathSlash)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,14 @@ public static boolean downloadArchive(
boolean s3SseKmsEnabled)
throws FileAlreadyExistsException, FileNotFoundException, DownloadArchiveException,
InvalidArchiveURLException {
if (validateURL(allowedUrls, url)) {
if (location.exists()) {
throw new FileAlreadyExistsException(archiveName);
}
try {
HttpUtils.copyURLToFile(new URL(url), location, s3SseKmsEnabled);
} catch (IOException e) {
FileUtils.deleteQuietly(location);
throw new DownloadArchiveException("Failed to download archive from: " + url, e);
}
try {
return HttpUtils.copyURLToFile(
allowedUrls, url, location, s3SseKmsEnabled, archiveName);
} catch (InvalidArchiveURLException | FileAlreadyExistsException e) {
throw e;
} catch (IOException e) {
FileUtils.deleteQuietly(location);
throw new DownloadArchiveException("Failed to download archive from: " + url, e);
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ public final class ConfigManager {

// Configuration default values
private static final String DEFAULT_TS_ALLOWED_URLS = "file://.*|http(s)?://.*";
private static final String USE_ENV_ALLOWED_URLS = "use_env_allowed_urls";

// Variables which are local
public static final String MODEL_METRICS_LOGGER = "MODEL_METRICS";
Expand Down Expand Up @@ -277,6 +278,14 @@ private void setSystemVars() {
Class<ConfigManager> configClass = ConfigManager.class;
Field[] fields = configClass.getDeclaredFields();
for (Field f : fields) {
// For security, ignore TS_ALLOWED_URLS from the environment unless
// use_env_allowed_urls is set to "true" in the properties.
if ("TS_ALLOWED_URLS".equals(f.getName())
&& !"true"
.equals(
prop.getProperty(USE_ENV_ALLOWED_URLS, "false")
.toLowerCase())) {
continue;
}
if (f.getName().startsWith("TS_")) {
String val = System.getenv(f.getName());
if (val != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.commons.io.FileUtils;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.Manifest;
import org.pytorch.serve.archive.model.ModelArchive;
Expand Down Expand Up @@ -238,7 +239,14 @@ private void setupModelDependencies(Model model)
null);

ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
processBuilder.directory(model.getModelDir().getAbsoluteFile());
if (isValidDependencyPath(dependencyPath)) {
processBuilder.directory(dependencyPath);
} else {
throw new ModelException(
"Invalid 3rd party package installation path "
+ dependencyPath.getCanonicalPath());
}

Map<String, String> environment = processBuilder.environment();
for (String envVar : envp) {
String[] parts = envVar.split("=", 2);
Expand Down Expand Up @@ -274,6 +282,16 @@ private void setupModelDependencies(Model model)
}
}

/**
 * Checks whether the given dependency path is located under the temp directory.
 *
 * <p>Both the candidate path and the temp directory path are normalized before the
 * prefix comparison, so redundant elements such as {@code ..} are collapsed first.
 *
 * @param dependencyPath the directory to validate
 * @return {@code true} if the normalized path starts with the normalized temp directory
 *     path, {@code false} otherwise
 */
private boolean isValidDependencyPath(File dependencyPath) {
    // The prefix test already yields the answer; no explicit true/false branches needed.
    return dependencyPath
            .toPath()
            .normalize()
            .startsWith(FileUtils.getTempDirectory().toPath().normalize());
}

private Model createModel(
ModelArchive archive,
int batchSize,
Expand Down

0 comments on commit 5804690

Please sign in to comment.