Skip to content

Commit

Permalink
added torch check created const for ENV_VARS_TRUE_VALUES removed sett…
Browse files Browse the repository at this point in the history
…ing values due to default values
  • Loading branch information
philschmid committed Mar 17, 2021
1 parent a7b25e6 commit eb69503
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,30 @@
else:
import importlib.metadata as importlib_metadata

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

TORCH_VERSION = "N/A"
TORCH_AVAILABLE = False
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
try:
TORCH_VERSION = importlib_metadata.version("torch")
TORCH_AVAILABLE = True
logger.info("PyTorch version {} available.".format(TORCH_VERSION))
except importlib_metadata.PackageNotFoundError:
pass

if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None
if TORCH_AVAILABLE:
try:
TORCH_VERSION = importlib_metadata.version("torch")
logger.info(f"PyTorch version {TORCH_VERSION} available.")
except importlib_metadata.PackageNotFoundError:
pass
else:
logger.info("Disabling PyTorch because USE_TF is set")

TF_VERSION = "N/A"
TF_AVAILABLE = False

if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
if TF_AVAILABLE:
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
Expand All @@ -71,8 +75,7 @@
try:
TF_VERSION = importlib_metadata.version("tf-nightly-gpu")
except importlib_metadata.PackageNotFoundError:
TF_VERSION = None
TF_AVAILABLE = False
pass
if TF_AVAILABLE:
if version.parse(TF_VERSION) < version.parse("2"):
logger.info(f"TensorFlow found but with version {TF_VERSION}. Transformers requires version 2 minimum.")
Expand All @@ -81,7 +84,6 @@
logger.info(f"TensorFlow version {TF_VERSION} available.")
else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
TF_AVAILABLE = False



Expand Down

1 comment on commit eb69503

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==0.17.1

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.021534 / 0.011353 (0.010181) 0.018033 / 0.011008 (0.007025) 0.049011 / 0.038508 (0.010503) 0.042247 / 0.023109 (0.019138) 0.225288 / 0.275898 (-0.050610) 0.253425 / 0.323480 (-0.070054) 0.006408 / 0.007986 (-0.001577) 0.004658 / 0.004328 (0.000330) 0.006994 / 0.004250 (0.002744) 0.046812 / 0.037052 (0.009760) 0.220272 / 0.258489 (-0.038217) 0.258752 / 0.293841 (-0.035089) 0.174227 / 0.128546 (0.045681) 0.137784 / 0.075646 (0.062138) 0.479473 / 0.419271 (0.060201) 0.692595 / 0.043533 (0.649062) 0.219978 / 0.255139 (-0.035161) 0.249301 / 0.283200 (-0.033899) 2.493747 / 0.141683 (2.352064) 1.974604 / 1.452155 (0.522449) 1.991036 / 1.492716 (0.498320)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.044914 / 0.037411 (0.007503) 0.021190 / 0.014526 (0.006664) 0.039632 / 0.176557 (-0.136924) 0.046700 / 0.737135 (-0.690435) 0.027774 / 0.296338 (-0.268565)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.260266 / 0.215209 (0.045057) 2.632115 / 2.077655 (0.554461) 1.381500 / 1.504120 (-0.122620) 1.232183 / 1.541195 (-0.309012) 1.230701 / 1.468490 (-0.237790) 7.672714 / 4.584777 (3.087937) 6.886960 / 3.745712 (3.141248) 9.653912 / 5.269862 (4.384050) 8.400753 / 4.565676 (3.835077) 0.755322 / 0.424275 (0.331047) 0.012362 / 0.007607 (0.004755) 0.312585 / 0.226044 (0.086540) 3.269205 / 2.268929 (1.000276) 1.974789 / 55.444624 (-53.469836) 1.677245 / 6.876477 (-5.199232) 1.652555 / 2.142072 (-0.489517) 7.937021 / 4.805227 (3.131794) 6.529247 / 6.500664 (0.028583) 9.144330 / 0.075469 (9.068861)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.453551 / 1.841788 (10.611764) 13.624655 / 8.074308 (5.550347) 23.579845 / 10.191392 (13.388453) 0.963894 / 0.680424 (0.283470) 0.318595 / 0.534201 (-0.215605) 0.875418 / 0.579283 (0.296135) 0.695874 / 0.434364 (0.261510) 0.802281 / 0.540337 (0.261943) 1.695754 / 1.386936 (0.308818)
PyArrow==1.0
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.019307 / 0.011353 (0.007954) 0.016761 / 0.011008 (0.005753) 0.049801 / 0.038508 (0.011293) 0.033476 / 0.023109 (0.010367) 0.347756 / 0.275898 (0.071858) 0.387276 / 0.323480 (0.063796) 0.006621 / 0.007986 (-0.001364) 0.005068 / 0.004328 (0.000739) 0.009467 / 0.004250 (0.005216) 0.052910 / 0.037052 (0.015858) 0.337857 / 0.258489 (0.079368) 0.396245 / 0.293841 (0.102404) 0.170845 / 0.128546 (0.042298) 0.145335 / 0.075646 (0.069689) 0.437335 / 0.419271 (0.018064) 0.482606 / 0.043533 (0.439073) 0.362359 / 0.255139 (0.107220) 0.381247 / 0.283200 (0.098047) 1.979942 / 0.141683 (1.838259) 2.048183 / 1.452155 (0.596028) 2.053453 / 1.492716 (0.560737)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.046789 / 0.037411 (0.009377) 0.023725 / 0.014526 (0.009199) 0.038969 / 0.176557 (-0.137588) 0.054001 / 0.737135 (-0.683134) 0.038601 / 0.296338 (-0.257737)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.345620 / 0.215209 (0.130410) 3.462922 / 2.077655 (1.385268) 2.173180 / 1.504120 (0.669060) 2.043735 / 1.541195 (0.502540) 2.030126 / 1.468490 (0.561636) 7.725060 / 4.584777 (3.140283) 6.863078 / 3.745712 (3.117365) 9.476038 / 5.269862 (4.206176) 8.312368 / 4.565676 (3.746691) 0.760776 / 0.424275 (0.336501) 0.011799 / 0.007607 (0.004192) 0.377777 / 0.226044 (0.151733) 3.991544 / 2.268929 (1.722615) 2.692559 / 55.444624 (-52.752065) 2.376052 / 6.876477 (-4.500424) 2.471869 / 2.142072 (0.329797) 7.649065 / 4.805227 (2.843838) 6.031654 / 6.500664 (-0.469010) 9.297989 / 0.075469 (9.222519)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 12.675116 / 1.841788 (10.833328) 13.935349 / 8.074308 (5.861041) 23.200090 / 10.191392 (13.008698) 0.833563 / 0.680424 (0.153139) 0.618351 / 0.534201 (0.084150) 0.862895 / 0.579283 (0.283612) 0.670236 / 0.434364 (0.235872) 0.749831 / 0.540337 (0.209493) 1.648524 / 1.386936 (0.261588)

CML watermark

Please sign in to comment.