From 1d7169f0097b52f503b0e8fc0b222ce19ee222e4 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 16:07:20 +0900 Subject: [PATCH 1/6] Add code format configurations --- Makefile | 9 + Pipfile | 5 + Pipfile.lock | 525 +++++++++++++++++++++++++++++++-------------------- README.rst | 28 ++- setup.cfg | 5 + tox.ini | 4 +- 6 files changed, 370 insertions(+), 206 deletions(-) create mode 100644 Makefile diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..104736ba --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +.PHONY: fmt +fmt: + pipenv run isort -rc . + pipenv run black . + +.PHONY: chk +chk: + pipenv run isort -c -rc . + pipenv run black --check --diff . diff --git a/Pipfile b/Pipfile index 4b40b13f..f530cd51 100644 --- a/Pipfile +++ b/Pipfile @@ -3,6 +3,9 @@ url = "https://pypi.python.org/simple" verify_ssl = true name = "pypi" +[pipenv] +allow_prereleases = true + [packages] e1839a8 = {path = ".",extras = ["sqlalchemy", "pandas"],editable = true} @@ -14,4 +17,6 @@ wheel = "*" pytest = ">=3.5" pytest-cov = "*" pytest-flake8 = ">=1.0.1" +pytest-black = "*" +pytest-isort = "*" pytest-xdist = "*" diff --git a/Pipfile.lock b/Pipfile.lock index bbb88e45..9fedaaeb 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "a2e21b41ea04396f4ae681dd6b7db437d454d2d912e5c22f5f732d024148862e" + "sha256": "d58bd09a928275f86cb6f3f4787f14242d35317998bad2b3cb006a23626080e7" }, "pipfile-spec": 6, "requires": {}, @@ -16,17 +16,17 @@ "default": { "boto3": { "hashes": [ - "sha256:57398de1b5e074e715c866441e69f90c9468959d5743a021d8aeed04fbaa1078", - "sha256:60ac1124597231ed36a7320547cd0d16a001bb92333ab30ad20514f77e585225" + "sha256:6374e39ea66433250ac6e4726623fe59ec32bfd52e5c871140c883ab72b0375c", + "sha256:7a2474db6576d7d3f5c3336ec54450e34211d44e2342e501a67e2fae35916e63" ], - "version": "==1.12.32" + "version": "==1.13.1" }, "botocore": { "hashes": [ - "sha256:3ea89601ee452b65084005278bd832be854cfde5166685dcb14b6c8f19d3fc6d", - "sha256:a963af564d94107787ff3d2c534e8b7aed7f12e014cdd609f8fcb17bf9d9b19a" + "sha256:879cedb22baf9446323240f1cf57d4e0e7ba262ba6fde6d3540cf7fdd7ddad34", + "sha256:c6d3dd976c1462a4e0f2dabde09c38de2a641e6b6bb0af505ac8465735796367" ], - "version": "==1.15.32" + "version": "==1.16.1" }, "docutils": { "hashes": [ @@ -50,14 +50,6 @@ ], "version": "==0.18.2" }, - "futures": { - "hashes": [ - "sha256:3a44f286998ae64f0cc083682fcfec16c406134a81a589a5de445d7bb7c2751b", - "sha256:51ecb45f0add83c806c68e4b06106f90db260585b25ef2abfcda0bd95c0132fd", - "sha256:c4884a65654a7c45435063e14ae85280eb1f111d94e542396717ba9828c4337f" - ], - "version": "==3.1.1" - }, "jmespath": { "hashes": [ "sha256:695cb76fa78a10663425d5b73ddc5714eb711157e52704d69be03b1a02ba4fec", @@ -67,29 +59,29 @@ }, "numpy": { "hashes": [ - "sha256:1598a6de323508cfeed6b7cd6c4efb43324f4692e20d1f76e1feec7f59013448", - "sha256:1b0ece94018ae21163d1f651b527156e1f03943b986188dd81bc7e066eae9d1c", - "sha256:2e40be731ad618cb4974d5ba60d373cdf4f1b8dcbf1dcf4d9dff5e212baf69c5", - "sha256:4ba59db1fcc27ea31368af524dcf874d9277f21fd2e1f7f1e2e0c75ee61419ed", - "sha256:59ca9c6592da581a03d42cc4e270732552243dc45e87248aa8d636d53812f6a5", - "sha256:5e0feb76849ca3e83dd396254e47c7dba65b3fa9ed3df67c2556293ae3e16de3", - "sha256:6d205249a0293e62bbb3898c4c2e1ff8a22f98375a34775a259a0523111a8f6c", - "sha256:6fcc5a3990e269f86d388f165a089259893851437b904f422d301cdce4ff25c8", - "sha256:82847f2765835c8e5308f136bc34018d09b49037ec23ecc42b246424c767056b", - 
"sha256:87902e5c03355335fc5992a74ba0247a70d937f326d852fc613b7f53516c0963", - "sha256:9ab21d1cb156a620d3999dd92f7d1c86824c622873841d6b080ca5495fa10fef", - "sha256:a1baa1dc8ecd88fb2d2a651671a84b9938461e8a8eed13e2f0a812a94084d1fa", - "sha256:a244f7af80dacf21054386539699ce29bcc64796ed9850c99a34b41305630286", - "sha256:a35af656a7ba1d3decdd4fae5322b87277de8ac98b7d9da657d9e212ece76a61", - "sha256:b1fe1a6f3a6f355f6c29789b5927f8bd4f134a4bd9a781099a7c4f66af8850f5", - "sha256:b5ad0adb51b2dee7d0ee75a69e9871e2ddfb061c73ea8bc439376298141f77f5", - "sha256:ba3c7a2814ec8a176bb71f91478293d633c08582119e713a0c5351c0f77698da", - "sha256:cd77d58fb2acf57c1d1ee2835567cd70e6f1835e32090538f17f8a3a99e5e34b", - "sha256:cdb3a70285e8220875e4d2bc394e49b4988bdb1298ffa4e0bd81b2f613be397c", - "sha256:deb529c40c3f1e38d53d5ae6cd077c21f1d49e13afc7936f7f868455e16b64a0", - "sha256:e7894793e6e8540dbeac77c87b489e331947813511108ae097f1715c018b8f3d" - ], - "version": "==1.18.2" + "sha256:0aa2b318cf81eb1693fcfcbb8007e95e231d7e1aa24288137f3b19905736c3ee", + "sha256:163c78c04f47f26ca1b21068cea25ed7c5ecafe5f5ab2ea4895656a750582b56", + "sha256:1e37626bcb8895c4b3873fcfd54e9bfc5ffec8d0f525651d6985fcc5c6b6003c", + "sha256:264fd15590b3f02a1fbc095e7e1f37cdac698ff3829e12ffdcffdce3772f9d44", + "sha256:3d9e1554cd9b5999070c467b18e5ae3ebd7369f02706a8850816f576a954295f", + "sha256:40c24960cd5cec55222963f255858a1c47c6fa50a65a5b03fd7de75e3700eaaa", + "sha256:46f404314dbec78cb342904f9596f25f9b16e7cf304030f1339e553c8e77f51c", + "sha256:4847f0c993298b82fad809ea2916d857d0073dc17b0510fbbced663b3265929d", + "sha256:48e15612a8357393d176638c8f68a19273676877caea983f8baf188bad430379", + "sha256:6725d2797c65598778409aba8cd67077bb089d5b7d3d87c2719b206dc84ec05e", + "sha256:99f0ba97e369f02a21bb95faa3a0de55991fd5f0ece2e30a9e2eaebeac238921", + "sha256:a41f303b3f9157a31ce7203e3ca757a0c40c96669e72d9b6ee1bce8507638970", + "sha256:a4305564e93f5c4584f6758149fd446df39fd1e0a8c89ca0deb3cce56106a027", + "sha256:a551d8cc267c634774830086da42e4ba157fa41dd3b93982bc9501b284b0c689", + "sha256:a6bc9432c2640b008d5f29bad737714eb3e14bb8854878eacf3d7955c4e91c36", + "sha256:c60175d011a2e551a2f74c84e21e7c982489b96b6a5e4b030ecdeacf2914da68", + "sha256:e46e2384209c91996d5ec16744234d1c906ab79a701ce1a26155c9ec890b8dc8", + "sha256:e607b8cdc2ae5d5a63cd1bec30a15b5ed583ac6a39f04b7ba0f03fcfbf29c05b", + "sha256:e94a39d5c40fffe7696009dbd11bc14a349b377e03a384ed011e03d698787dd3", + "sha256:eb2286249ebfe8fcb5b425e5ec77e4736d53ee56d3ad296f8947f67150f495e3", + "sha256:fdee7540d12519865b423af411bd60ddb513d2eb2cd921149b732854995bbf8b" + ], + "version": "==1.18.3" }, "pandas": { "hashes": [ @@ -114,34 +106,29 @@ }, "pyarrow": { "hashes": [ - "sha256:00abec64636aa506d948926ab5dd37fdfe8c0407b069602ba16c68c19ccb0257", - "sha256:09e9046e3dc24b5c81d307d150b8c04b127aa9f9b3c6babcf13313f2448dd185", - "sha256:2ff6e7b0411e3e163cc6465f1ed6a680f0c78b4ff6a4f507d29eb4ed65860557", - "sha256:53d3f3684ca0cc12b64f2446022e2ab4a9b0b0976bba0f47ea53ea16b6af4ece", - "sha256:5449408037c761a0622d13cc0c21756fcce2ea7346ea9c73e2abf8cdd8385ea2", - "sha256:5af1cc49225aaf82a3dfbda22e5533d339f540921ea001ba36b0d6d5ad364e2b", - "sha256:5fede6cb5d9fda323098042ece0597f40e5bd78520b87e7b8efdd8f062846ad8", - "sha256:7aebec0f1b76e73a6307b5027618c843eadb4dc4f6e1f08ca496a01a7273ac64", - "sha256:8663ca4ca5c27fcb5c8bfc5c7b7e8780b9d699e47da1cad1b7b170eff98498b5", - "sha256:890b9a7d6e2c61968ba93e535fc1cf116e66eea2fcc2d6b2503b44e190f3bc47", - "sha256:899d7316ea5610798c42e13ffb1d73323600168ccd6d8f0d58ce9e665b7a341f", - 
"sha256:8a00a8497e2367c4f206bb8b7df01852d1e3f1261107ee77a217af654793ac0e", - "sha256:8d212c2c93706fafff39a71bee3d42dfd1ca393fda31ce5e3a05c620e1886a7f", - "sha256:94d89482bb5461c55b2ee33eafd44294c7f1244cc9e390ea7855f647957113f7", - "sha256:a609354433dd31ffc4c8de8637de915391fd6ff781b3d8c5d51d3f4eec6fcf39", - "sha256:ac83d595f9b469bea712ce998270038b08b40794abd7374e4bce2ecf5ee2c1cb", - "sha256:bb6bb7ba1b6a1c3c94cc0d0068c96df9498c973ad0ae6ca398164d339b704c97", - "sha256:c1214f1689711d6562df70863cbd62d6f2a83e68214bb4c97c489f2f97ddeaf4", - "sha256:caf50dfcc709c7cfca4f816e9b4442222e9e6d3ec51c2618fb6bde8a73c59be4", - "sha256:d746e5f34240199ef8afdd0efb391692b85b1ce3e098febd887efc2128da6570", - "sha256:db6d7ec70beeaea468c9c47241f95e2eecfaa2dbb4a27965bf1f952c12680fe9", - "sha256:dcd9347797578b0f65a6fb0cb76f462d5d0d63148f51ac8f9c9b5be9acc3f40e", - "sha256:dd18bc60cef3e72f8082c46de4cfb0cf9fb294c0ff7a201e2b95924fb5d2d146", - "sha256:df8ff1c5de2e454dcab9421d70d0db3985ad4efc40899d947687ca6d36846fc8", - "sha256:e6c042f192c9a0ba33a927a8d0a1e6bfe3ab29aa48a74fc48040d32b07d65124", - "sha256:fab386e5403cec3f66e1ac1375f3648351f9415f28d7740ee0f813d1fc0a326a" - ], - "version": "==0.16.0" + "sha256:0dec3a824469a14dda5fcc5d657db8b3eff8ef3e549f3d9008bf62d221775ce3", + "sha256:116c76d4484151ad3ea7db3b0b1d3ce3b21d7a5707950da4586f42580c8acf4e", + "sha256:291d6179ac6008c07cd1e5f0034aacaf07603252234dbaa48edc55c29b42ddc5", + "sha256:2eb79ff5bd153291822eb3624e710f754172a6c448224a819988652a86075e35", + "sha256:37ad7949eeef178403637b86f9f165fcd36d7259031bc06635b9cd3a6ab2d60c", + "sha256:3b485c8f92e64b4500bb495a1303ebd43ceec16961c95c938c7e8630c6249cd8", + "sha256:40b974e719a190f7404907fa99c786fc860415246e2121e61657b1c8c7c045df", + "sha256:7f4ec14f3c2036508c08fa1bd805847bfb1bb4a1be26e01e799bb594e174f477", + "sha256:8b2f5843271b4134c94ec0a632c26bd56b0598ddaa0e123d773107bbb8f70c6a", + "sha256:99a730f4a1860a47a8566644688457da3eb3793df397e2b8eacd53c5d652f045", + "sha256:a362331ac8a7a7c6b7684e205af6046911e495a3d19ccf2b300e3081c39321d1", + "sha256:b6492d8e35392f720a74bc7a0bb9680b6c8b615d3beefaa1e8b9634fcef2a78d", + "sha256:b8fb9b086f2bb5baefff80e0ae78712e95470954ad4dc1b0aa8450b54d63790a", + "sha256:c0574ce60d714d2de5076d8fb2a6dd1e77350803f1bcfb15422d03621e4c0db1", + "sha256:c17f3d4bfd2c9d5c88ef9d1acabcb31f9015edafcbfb7faf9fd4e5a9e7a5012d", + "sha256:c2543fce53db942456b33ec8a4103a35b58648b27ff9fb93f405ff25b034546b", + "sha256:c8f428cd9885c5727c17e1af3c850b67f2bed804e2ed88022908e159fcd83ac6", + "sha256:d3ccfe1408c93abefa5ea79a295782136c8ad9bb8efd025b9df4685e081f021e", + "sha256:de39a27b339247a2b9ced01dd91051a5525d2fbaf66f1db6cf9242634fb3365d", + "sha256:eb6f90342909ba50abde964e529f3165068dd76ffd2c15408894f0cb0bc3e581", + "sha256:fb1cdfda872700feee8ce8cc1f4a7e4220728d6e31eb067a0f95595f11e724f3" + ], + "version": "==0.17.0" }, "python-dateutil": { "hashes": [ @@ -152,10 +139,10 @@ }, "pytz": { "hashes": [ - "sha256:1c557d7d0e871de1f5ccd5833f60fb2550652da6be2693c1e02300743d21500d", - "sha256:b02c06db6cf09c12dd25137e563b31700d3b80fcc4ad23abb7a315f2789819be" + "sha256:a494d53b6d39c3c6e44c3bec237336e14305e4f29bbf800b599253057fbb79ed", + "sha256:c35965d010ce31b23eeb663ed3cc8c906275d6be1a34393a1d73a41febf4a048" ], - "version": "==2019.3" + "version": "==2020.1" }, "s3transfer": { "hashes": [ @@ -171,20 +158,44 @@ ], "version": "==1.14.0" }, + "sqlalchemy": { + "hashes": [ + "sha256:083e383a1dca8384d0ea6378bd182d83c600ed4ff4ec8247d3b2442cf70db1ad", + "sha256:0a690a6486658d03cc6a73536d46e796b6570ac1f8a7ec133f9e28c448b69828", + 
"sha256:114b6ace30001f056e944cebd46daef38fdb41ebb98f5e5940241a03ed6cad43", + "sha256:128f6179325f7597a46403dde0bf148478f868df44841348dfc8d158e00db1f9", + "sha256:13d48cd8b925b6893a4e59b2dfb3e59a5204fd8c98289aad353af78bd214db49", + "sha256:211a1ce7e825f7142121144bac76f53ac28b12172716a710f4bf3eab477e730b", + "sha256:2dc57ee80b76813759cccd1a7affedf9c4dbe5b065a91fb6092c9d8151d66078", + "sha256:3e625e283eecc15aee5b1ef77203bfb542563fa4a9aa622c7643c7b55438ff49", + "sha256:43078c7ec0457387c79b8d52fff90a7ad352ca4c7aa841c366238c3e2cf52fdf", + "sha256:5b1bf3c2c2dca738235ce08079783ef04f1a7fc5b21cf24adaae77f2da4e73c3", + "sha256:6056b671aeda3fc451382e52ab8a753c0d5f66ef2a5ccc8fa5ba7abd20988b4d", + "sha256:68d78cf4a9dfade2e6cf57c4be19f7b82ed66e67dacf93b32bb390c9bed12749", + "sha256:7025c639ce7e170db845e94006cf5f404e243e6fc00d6c86fa19e8ad8d411880", + "sha256:7224e126c00b8178dfd227bc337ba5e754b197a3867d33b9f30dc0208f773d70", + "sha256:7d98e0785c4cd7ae30b4a451416db71f5724a1839025544b4edbd92e00b91f0f", + "sha256:8d8c21e9d4efef01351bf28513648ceb988031be4159745a7ad1b3e28c8ff68a", + "sha256:bbb545da054e6297242a1bb1ba88e7a8ffb679f518258d66798ec712b82e4e07", + "sha256:d00b393f05dbd4ecd65c989b7f5a81110eae4baea7a6a4cdd94c20a908d1456e", + "sha256:e18752cecaef61031252ca72031d4d6247b3212ebb84748fc5d1a0d2029c23ea" + ], + "version": "==1.3.16" + }, "tenacity": { "hashes": [ - "sha256:f7bcbf5bb53875cfc38f61f596a88b1c994af32420f120c4409542a683ad613b", - "sha256:fb01d8ef2474eed422d8314a566c9e391cccb6e50cf4585022add6cb5cda66c8" + "sha256:29ae90e7faf488a8628432154bb34ace1cca58244c6ea399fd33f066ac71339a", + "sha256:5a5d3dcd46381abe8b4f82b5736b8726fd3160c6c7161f53f8af7f1eb9b82173" ], - "version": "==6.1.0" + "version": "==6.2.0" }, "urllib3": { "hashes": [ - "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", - "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" + "sha256:3018294ebefce6572a474f0604c2021e33b3fd8006ecd11d62107a5d2a963527", + "sha256:88206b0eb87e6d677d424843ac5209e3fb9d0190d0ee169599165ec25e9d9115" ], "markers": "python_version != '3.4'", - "version": "==1.25.8" + "version": "==1.25.9" } }, "develop": { @@ -211,33 +222,38 @@ }, "awscli": { "hashes": [ - "sha256:4c49f085fb827ca1aeba5e6e5e39f6005110a0059b5c772aeb1d51c4f33c4028", - "sha256:9459ac705c2a5d8724057492800c52084df714b624853eb3331087ecf8726a22" + "sha256:7834a373a859cde9273ab2f17774c05a0812ecafe9934d69a73f8d9dfd99290b" ], "index": "pypi", - "version": "==1.17.9" + "version": "==1.18.51" + }, + "black": { + "hashes": [ + "sha256:1b30e59be925fafc1ee4565e5e08abef6b03fe455102883820fe5ee2e4734e0b", + "sha256:c2edb73a08e9e0e6f65a0e6af18b059b8b1cdd5bef997d7a0b181df93dc81539" + ], + "version": "==19.10b0" }, "bleach": { "hashes": [ - "sha256:cc8da25076a1fe56c3ac63671e2194458e0c4d9c7becfd52ca251650d517903c", - "sha256:e78e426105ac07026ba098f04de8abe9b6e3e98b5befbf89b51a5ef0a4292b03" + "sha256:2bce3d8fab545a6528c8fa5d9f9ae8ebc85a56da365c7f85180bfe96a35ef22f", + "sha256:3c4c520fdb9db59ef139915a5db79f8b51bc2a7257ea0389f30c846883430a4b" ], - "index": "pypi", - "version": "==3.1.4" + "version": "==3.1.5" }, "botocore": { "hashes": [ - "sha256:3ea89601ee452b65084005278bd832be854cfde5166685dcb14b6c8f19d3fc6d", - "sha256:a963af564d94107787ff3d2c534e8b7aed7f12e014cdd609f8fcb17bf9d9b19a" + "sha256:879cedb22baf9446323240f1cf57d4e0e7ba262ba6fde6d3540cf7fdd7ddad34", + "sha256:c6d3dd976c1462a4e0f2dabde09c38de2a641e6b6bb0af505ac8465735796367" ], - "version": "==1.15.32" + "version": "==1.16.1" }, "certifi": { "hashes": [ - 
"sha256:017c25db2a153ce562900032d5bc68e9f191e44e9a0f762f373977de9df1fbb3", - "sha256:25b64c7da4cd7479594d035c08c2d809eb4aab3a26e5a990ea98cc450c320f1f" + "sha256:1d987a998c75633c40847cc966fcf5904906c920a7f17ef374f5aa4282abd304", + "sha256:51fcb31174be6e6664c5f69e3e1691a2d72a1a12e90f872cbdb1567eb47b6519" ], - "version": "==2019.11.28" + "version": "==2020.4.5.1" }, "cffi": { "hashes": [ @@ -279,74 +295,79 @@ ], "version": "==3.0.4" }, + "click": { + "hashes": [ + "sha256:d2b5255c7c6349bc1bd1e59e08cd12acbbd63ce649f2588755783aa94dfb6b1a", + "sha256:dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc" + ], + "version": "==7.1.2" + }, "colorama": { "hashes": [ - "sha256:05eed71e2e327246ad6b38c540c4a3117230b19679b875190486ddd2d721422d", - "sha256:f8ac84de7840f5b9c4e3347b3c1eaa50f7e49c2b07596221daec5edaabbd7c48" + "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff", + "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1" ], - "version": "==0.4.1" + "version": "==0.4.3" }, "coverage": { "hashes": [ - "sha256:03f630aba2b9b0d69871c2e8d23a69b7fe94a1e2f5f10df5049c0df99db639a0", - "sha256:046a1a742e66d065d16fb564a26c2a15867f17695e7f3d358d7b1ad8a61bca30", - "sha256:0a907199566269e1cfa304325cc3b45c72ae341fbb3253ddde19fa820ded7a8b", - "sha256:165a48268bfb5a77e2d9dbb80de7ea917332a79c7adb747bd005b3a07ff8caf0", - "sha256:1b60a95fc995649464e0cd48cecc8288bac5f4198f21d04b8229dc4097d76823", - "sha256:1f66cf263ec77af5b8fe14ef14c5e46e2eb4a795ac495ad7c03adc72ae43fafe", - "sha256:2e08c32cbede4a29e2a701822291ae2bc9b5220a971bba9d1e7615312efd3037", - "sha256:3844c3dab800ca8536f75ae89f3cf566848a3eb2af4d9f7b1103b4f4f7a5dad6", - "sha256:408ce64078398b2ee2ec08199ea3fcf382828d2f8a19c5a5ba2946fe5ddc6c31", - "sha256:443be7602c790960b9514567917af538cac7807a7c0c0727c4d2bbd4014920fd", - "sha256:4482f69e0701139d0f2c44f3c395d1d1d37abd81bfafbf9b6efbe2542679d892", - "sha256:4a8a259bf990044351baf69d3b23e575699dd60b18460c71e81dc565f5819ac1", - "sha256:513e6526e0082c59a984448f4104c9bf346c2da9961779ede1fc458e8e8a1f78", - "sha256:5f587dfd83cb669933186661a351ad6fc7166273bc3e3a1531ec5c783d997aac", - "sha256:62061e87071497951155cbccee487980524d7abea647a1b2a6eb6b9647df9006", - "sha256:641e329e7f2c01531c45c687efcec8aeca2a78a4ff26d49184dce3d53fc35014", - "sha256:65a7e00c00472cd0f59ae09d2fb8a8aaae7f4a0cf54b2b74f3138d9f9ceb9cb2", - "sha256:6ad6ca45e9e92c05295f638e78cd42bfaaf8ee07878c9ed73e93190b26c125f7", - "sha256:73aa6e86034dad9f00f4bbf5a666a889d17d79db73bc5af04abd6c20a014d9c8", - "sha256:7c9762f80a25d8d0e4ab3cb1af5d9dffbddb3ee5d21c43e3474c84bf5ff941f7", - "sha256:85596aa5d9aac1bf39fe39d9fa1051b0f00823982a1de5766e35d495b4a36ca9", - "sha256:86a0ea78fd851b313b2e712266f663e13b6bc78c2fb260b079e8b67d970474b1", - "sha256:8a620767b8209f3446197c0e29ba895d75a1e272a36af0786ec70fe7834e4307", - "sha256:922fb9ef2c67c3ab20e22948dcfd783397e4c043a5c5fa5ff5e9df5529074b0a", - "sha256:9fad78c13e71546a76c2f8789623eec8e499f8d2d799f4b4547162ce0a4df435", - "sha256:a37c6233b28e5bc340054cf6170e7090a4e85069513320275a4dc929144dccf0", - "sha256:c3fc325ce4cbf902d05a80daa47b645d07e796a80682c1c5800d6ac5045193e5", - "sha256:cda33311cb9fb9323958a69499a667bd728a39a7aa4718d7622597a44c4f1441", - "sha256:db1d4e38c9b15be1521722e946ee24f6db95b189d1447fa9ff18dd16ba89f732", - "sha256:eda55e6e9ea258f5e4add23bcf33dc53b2c319e70806e180aecbff8d90ea24de", - "sha256:f372cdbb240e09ee855735b9d85e7f50730dcfb6296b74b95a3e5dea0615c4c1" - ], - "version": "==5.0.4" + "sha256:00f1d23f4336efc3b311ed0d807feb45098fc86dee1ca13b3d6768cdab187c8a", + 
"sha256:01333e1bd22c59713ba8a79f088b3955946e293114479bbfc2e37d522be03355", + "sha256:0cb4be7e784dcdc050fc58ef05b71aa8e89b7e6636b99967fadbdba694cf2b65", + "sha256:0e61d9803d5851849c24f78227939c701ced6704f337cad0a91e0972c51c1ee7", + "sha256:1601e480b9b99697a570cea7ef749e88123c04b92d84cedaa01e117436b4a0a9", + "sha256:2742c7515b9eb368718cd091bad1a1b44135cc72468c731302b3d641895b83d1", + "sha256:2d27a3f742c98e5c6b461ee6ef7287400a1956c11421eb574d843d9ec1f772f0", + "sha256:402e1744733df483b93abbf209283898e9f0d67470707e3c7516d84f48524f55", + "sha256:5c542d1e62eece33c306d66fe0a5c4f7f7b3c08fecc46ead86d7916684b36d6c", + "sha256:5f2294dbf7875b991c381e3d5af2bcc3494d836affa52b809c91697449d0eda6", + "sha256:6402bd2fdedabbdb63a316308142597534ea8e1895f4e7d8bf7476c5e8751fef", + "sha256:66460ab1599d3cf894bb6baee8c684788819b71a5dc1e8fa2ecc152e5d752019", + "sha256:782caea581a6e9ff75eccda79287daefd1d2631cc09d642b6ee2d6da21fc0a4e", + "sha256:79a3cfd6346ce6c13145731d39db47b7a7b859c0272f02cdb89a3bdcbae233a0", + "sha256:7a5bdad4edec57b5fb8dae7d3ee58622d626fd3a0be0dfceda162a7035885ecf", + "sha256:8fa0cbc7ecad630e5b0f4f35b0f6ad419246b02bc750de7ac66db92667996d24", + "sha256:a027ef0492ede1e03a8054e3c37b8def89a1e3c471482e9f046906ba4f2aafd2", + "sha256:a3f3654d5734a3ece152636aad89f58afc9213c6520062db3978239db122f03c", + "sha256:a82b92b04a23d3c8a581fc049228bafde988abacba397d57ce95fe95e0338ab4", + "sha256:acf3763ed01af8410fc36afea23707d4ea58ba7e86a8ee915dfb9ceff9ef69d0", + "sha256:adeb4c5b608574a3d647011af36f7586811a2c1197c861aedb548dd2453b41cd", + "sha256:b83835506dfc185a319031cf853fa4bb1b3974b1f913f5bb1a0f3d98bdcded04", + "sha256:bb28a7245de68bf29f6fb199545d072d1036a1917dca17a1e75bbb919e14ee8e", + "sha256:bf9cb9a9fd8891e7efd2d44deb24b86d647394b9705b744ff6f8261e6f29a730", + "sha256:c317eaf5ff46a34305b202e73404f55f7389ef834b8dbf4da09b9b9b37f76dd2", + "sha256:dbe8c6ae7534b5b024296464f387d57c13caa942f6d8e6e0346f27e509f0f768", + "sha256:de807ae933cfb7f0c7d9d981a053772452217df2bf38e7e6267c9cbf9545a796", + "sha256:dead2ddede4c7ba6cb3a721870f5141c97dc7d85a079edb4bd8d88c3ad5b20c7", + "sha256:dec5202bfe6f672d4511086e125db035a52b00f1648d6407cc8e526912c0353a", + "sha256:e1ea316102ea1e1770724db01998d1603ed921c54a86a2efcb03428d5417e489", + "sha256:f90bfc4ad18450c80b024036eaf91e4a246ae287701aaa88eaebebf150868052" + ], + "version": "==5.1" }, "cryptography": { "hashes": [ - "sha256:02079a6addc7b5140ba0825f542c0869ff4df9a69c360e339ecead5baefa843c", - "sha256:1df22371fbf2004c6f64e927668734070a8953362cd8370ddd336774d6743595", - "sha256:369d2346db5934345787451504853ad9d342d7f721ae82d098083e1f49a582ad", - "sha256:3cda1f0ed8747339bbdf71b9f38ca74c7b592f24f65cdb3ab3765e4b02871651", - "sha256:44ff04138935882fef7c686878e1c8fd80a723161ad6a98da31e14b7553170c2", - "sha256:4b1030728872c59687badcca1e225a9103440e467c17d6d1730ab3d2d64bfeff", - "sha256:58363dbd966afb4f89b3b11dfb8ff200058fbc3b947507675c19ceb46104b48d", - "sha256:6ec280fb24d27e3d97aa731e16207d58bd8ae94ef6eab97249a2afe4ba643d42", - "sha256:7270a6c29199adc1297776937a05b59720e8a782531f1f122f2eb8467f9aab4d", - "sha256:73fd30c57fa2d0a1d7a49c561c40c2f79c7d6c374cc7750e9ac7c99176f6428e", - "sha256:7f09806ed4fbea8f51585231ba742b58cbcfbfe823ea197d8c89a5e433c7e912", - "sha256:90df0cc93e1f8d2fba8365fb59a858f51a11a394d64dbf3ef844f783844cc793", - "sha256:971221ed40f058f5662a604bd1ae6e4521d84e6cad0b7b170564cc34169c8f13", - "sha256:a518c153a2b5ed6b8cc03f7ae79d5ffad7315ad4569b2d5333a13c38d64bd8d7", - "sha256:b0de590a8b0979649ebeef8bb9f54394d3a41f66c5584fff4220901739b6b2f0", - 
"sha256:b43f53f29816ba1db8525f006fa6f49292e9b029554b3eb56a189a70f2a40879", - "sha256:d31402aad60ed889c7e57934a03477b572a03af7794fa8fb1780f21ea8f6551f", - "sha256:de96157ec73458a7f14e3d26f17f8128c959084931e8997b9e655a39c8fde9f9", - "sha256:df6b4dca2e11865e6cfbfb708e800efb18370f5a46fd601d3755bc7f85b3a8a2", - "sha256:ecadccc7ba52193963c0475ac9f6fa28ac01e01349a2ca48509667ef41ffd2cf", - "sha256:fb81c17e0ebe3358486cd8cc3ad78adbae58af12fc2bf2bc0bb84e8090fa5ce8" - ], - "version": "==2.8" + "sha256:091d31c42f444c6f519485ed528d8b451d1a0c7bf30e8ca583a0cac44b8a0df6", + "sha256:18452582a3c85b96014b45686af264563e3e5d99d226589f057ace56196ec78b", + "sha256:1dfa985f62b137909496e7fc182dac687206d8d089dd03eaeb28ae16eec8e7d5", + "sha256:1e4014639d3d73fbc5ceff206049c5a9a849cefd106a49fa7aaaa25cc0ce35cf", + "sha256:22e91636a51170df0ae4dcbd250d318fd28c9f491c4e50b625a49964b24fe46e", + "sha256:3b3eba865ea2754738616f87292b7f29448aec342a7c720956f8083d252bf28b", + "sha256:651448cd2e3a6bc2bb76c3663785133c40d5e1a8c1a9c5429e4354201c6024ae", + "sha256:726086c17f94747cedbee6efa77e99ae170caebeb1116353c6cf0ab67ea6829b", + "sha256:844a76bc04472e5135b909da6aed84360f522ff5dfa47f93e3dd2a0b84a89fa0", + "sha256:88c881dd5a147e08d1bdcf2315c04972381d026cdb803325c03fe2b4a8ed858b", + "sha256:96c080ae7118c10fcbe6229ab43eb8b090fccd31a09ef55f83f690d1ef619a1d", + "sha256:a0c30272fb4ddda5f5ffc1089d7405b7a71b0b0f51993cb4e5dbb4590b2fc229", + "sha256:bb1f0281887d89617b4c68e8db9a2c42b9efebf2702a3c5bf70599421a8623e3", + "sha256:c447cf087cf2dbddc1add6987bbe2f767ed5317adb2d08af940db517dd704365", + "sha256:c4fd17d92e9d55b84707f4fd09992081ba872d1a0c610c109c18e062e06a2e55", + "sha256:d0d5aeaedd29be304848f1c5059074a740fa9f6f26b84c5b63e8b29e73dfc270", + "sha256:daf54a4b07d67ad437ff239c8a4080cfd1cc7213df57d33c97de7b4738048d5e", + "sha256:e993468c859d084d5579e2ebee101de8f5a27ce8e2159959b6673b418fd8c785", + "sha256:f118a95c7480f5be0df8afeb9a11bd199aa20afab7a96bcf20409b411a3a85f0" + ], + "version": "==2.9.2" }, "distlib": { "hashes": [ @@ -362,13 +383,6 @@ ], "version": "==0.15.2" }, - "entrypoints": { - "hashes": [ - "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", - "sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451" - ], - "version": "==0.3" - }, "execnet": { "hashes": [ "sha256:cacb9df31c9680ec5f95553976c4da484d407e85e41c83cb812aa014f0eddc50", @@ -385,10 +399,10 @@ }, "flake8": { "hashes": [ - "sha256:45681a117ecc81e870cbf1262835ae4af5e7a8b08e40b944a8a6e6b895914cfb", - "sha256:49356e766643ad15072a789a20915d3c91dc89fd313ccd71802303fd67e4deca" + "sha256:c09e7e4ea0d91fa36f7b8439ca158e592be56524f0b67c39ab0ea2b85ed8f9a4", + "sha256:f33c5320eaa459cdee6367016a4bf4ba2a9b81499ce56e6a32abbf0b8d3a2eb4" ], - "version": "==3.7.9" + "version": "==3.8.0a2" }, "idna": { "hashes": [ @@ -397,6 +411,29 @@ ], "version": "==2.9" }, + "importlib-metadata": { + "hashes": [ + "sha256:2a688cbaa90e0cc587f1df48bdc97a6eadccdcd9c35fb3f976a09e3b5016d90f", + "sha256:34513a8a0c4962bc66d35b359558fd8a5e10cd472d37aec5f66858addef32c1e" + ], + "markers": "python_version < '3.8'", + "version": "==1.6.0" + }, + "importlib-resources": { + "hashes": [ + "sha256:6f87df66833e1942667108628ec48900e02a4ab4ad850e25fbf07cb17cf734ca", + "sha256:85dc0b9b325ff78c8bef2e4ff42616094e16b98ebd5e3b50fe7e2f0bbcdcde49" + ], + "markers": "python_version < '3.7'", + "version": "==1.5.0" + }, + "isort": { + "hashes": [ + "sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1", + 
"sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd" + ], + "version": "==4.3.21" + }, "jeepney": { "hashes": [ "sha256:3479b861cc2b6407de5188695fa1a8d57e5072d7059322469b62628869b8e36e", @@ -414,10 +451,10 @@ }, "keyring": { "hashes": [ - "sha256:197fd5903901030ef7b82fe247f43cfed2c157a28e7747d1cfcf4bc5e699dd03", - "sha256:8179b1cdcdcbc221456b5b74e6b7cfa06f8dd9f239eb81892166d9223d82c5ba" + "sha256:3401234209015144a5d75701e71cb47239e552b0882313e9f51e8976f9e27843", + "sha256:c53e0e5ccde3ad34284a40ce7976b5b3a3d6de70344c3f8ee44364cc340976ec" ], - "version": "==21.2.0" + "version": "==21.2.1" }, "mccabe": { "hashes": [ @@ -440,6 +477,13 @@ ], "version": "==20.3" }, + "pathspec": { + "hashes": [ + "sha256:7d91249d21749788d07a2d0f94147accd8f845507400749ea19c1ec9054a12b0", + "sha256:da45173eb3a6f2a5a487efba21f050af2b41948be6ab52b6a1e3ff22bb8b7061" + ], + "version": "==0.8.0" + }, "pkginfo": { "hashes": [ "sha256:7424f2c8511c186cd5424bbf31045b77435b37a8d604990b79d4e70d741148bb", @@ -470,10 +514,10 @@ }, "pycodestyle": { "hashes": [ - "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", - "sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c" + "sha256:933bfe8d45355fbb35f9017d81fc51df8cb7ce58b82aca2568b870bf7bea1611", + "sha256:c1362bf675a7c0171fa5f795917c570c2e405a97e5dc473b51f3656075d73acc" ], - "version": "==2.5.0" + "version": "==2.6.0a1" }, "pycparser": { "hashes": [ @@ -484,10 +528,10 @@ }, "pyflakes": { "hashes": [ - "sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0", - "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2" + "sha256:0d94e0e05a19e57a99444b6ddcf9a6eb2e5c68d3ca1e98e90707af8152c90a92", + "sha256:35b2d75ee967ea93b55750aa9edbbf72813e06a66ba54438df2cfac9e3c27fc8" ], - "version": "==2.1.1" + "version": "==2.2.0" }, "pygments": { "hashes": [ @@ -498,18 +542,25 @@ }, "pyparsing": { "hashes": [ - "sha256:4c830582a84fb022400b85429791bc551f1f4871c33f23e44f353119e92f969f", - "sha256:c342dccb5250c08d45fd6f8b4a559613ca603b57498511740e65cd11a2e7dcec" + "sha256:67199f0c41a9c702154efb0e7a8cc08accf830eb003b4d9fa42c4059002e2492", + "sha256:700d17888d441604b0bd51535908dcb297561b040819cccde647a92439db5a2a" ], - "version": "==2.4.6" + "version": "==3.0.0a1" }, "pytest": { "hashes": [ - "sha256:0d5fe9189a148acc3c3eb2ac8e1ac0742cb7618c084f3d228baaec0c254b318d", - "sha256:ff615c761e25eb25df19edddc0b970302d2a9091fbce0e7213298d85fb61fef6" + "sha256:0e5b30f5cb04e887b91b1ee519fa3d89049595f428c1db76e73bd7f17b09b172", + "sha256:84dde37075b8805f3d1f392cc47e38a0e59518fb46a431cfdaf7cf1ce805f970" + ], + "index": "pypi", + "version": "==5.4.1" + }, + "pytest-black": { + "hashes": [ + "sha256:01a9a7acc69e618ebf3f834932a4d7a81909f6911051d0871b0ed4de3cbe9712" ], "index": "pypi", - "version": "==5.3.5" + "version": "==0.3.8" }, "pytest-cov": { "hashes": [ @@ -521,11 +572,11 @@ }, "pytest-flake8": { "hashes": [ - "sha256:4d225c13e787471502ff94409dcf6f7927049b2ec251c63b764a4b17447b60c0", - "sha256:d7e2b6b274a255b7ae35e9224c85294b471a83b76ecb6bd53c337ae977a499af" + "sha256:6e26d94ad41184d9a5113a90179b303efddb53eda505f827418ca78f5b39403a", + "sha256:d85efaafbdb9580791cfa8671799dd40d482fc30bd4476c1ca5efd661e751333" ], "index": "pypi", - "version": "==1.0.4" + "version": "==1.0.5" }, "pytest-forked": { "hashes": [ @@ -534,6 +585,14 @@ ], "version": "==1.1.3" }, + "pytest-isort": { + "hashes": [ + "sha256:5d47dd4c45a7c2eb4a0401ae4febe143724dd8a2acf1e7317c80145bac8b608a", + 
"sha256:758156cb4dc1db72adc1b7e253011f5eea117fab32af03cedb4cbfc6058b5f8f" + ], + "index": "pypi", + "version": "==1.0.0" + }, "pytest-xdist": { "hashes": [ "sha256:0f46020d3d9619e6d17a65b5b989c1ebbb58fc7b1da8fb126d70f4bac4dfeed1", @@ -551,26 +610,52 @@ }, "pyyaml": { "hashes": [ - "sha256:0e7f69397d53155e55d10ff68fdfb2cf630a35e6daf65cf0bdeaf04f127c09dc", - "sha256:2e9f0b7c5914367b0916c3c104a024bb68f269a486b9d04a2e8ac6f6597b7803", - "sha256:35ace9b4147848cafac3db142795ee42deebe9d0dad885ce643928e88daebdcc", - "sha256:38a4f0d114101c58c0f3a88aeaa44d63efd588845c5a2df5290b73db8f246d15", - "sha256:483eb6a33b671408c8529106df3707270bfacb2447bf8ad856a4b4f57f6e3075", - "sha256:4b6be5edb9f6bb73680f5bf4ee08ff25416d1400fbd4535fe0069b2994da07cd", - "sha256:7f38e35c00e160db592091751d385cd7b3046d6d51f578b29943225178257b31", - "sha256:8100c896ecb361794d8bfdb9c11fce618c7cf83d624d73d5ab38aef3bc82d43f", - "sha256:c0ee8eca2c582d29c3c2ec6e2c4f703d1b7f1fb10bc72317355a746057e7346c", - "sha256:e4c015484ff0ff197564917b4b4246ca03f411b9bd7f16e02a2f586eb48b6d04", - "sha256:ebc4ed52dcc93eeebeae5cf5deb2ae4347b3a81c3fa12b0b8c976544829396a4" + "sha256:06a0d7ba600ce0b2d2fe2e78453a470b5a6e000a985dd4a4e54e436cc36b0e97", + "sha256:240097ff019d7c70a4922b6869d8a86407758333f02203e0fc6ff79c5dcede76", + "sha256:4f4b913ca1a7319b33cfb1369e91e50354d6f07a135f3b901aca02aa95940bd2", + "sha256:69f00dca373f240f842b2931fb2c7e14ddbacd1397d57157a9b005a6a9942648", + "sha256:73f099454b799e05e5ab51423c7bcf361c58d3206fa7b0d555426b1f4d9a3eaf", + "sha256:74809a57b329d6cc0fdccee6318f44b9b8649961fa73144a98735b0aaf029f1f", + "sha256:7739fc0fa8205b3ee8808aea45e968bc90082c10aef6ea95e855e10abf4a37b2", + "sha256:95f71d2af0ff4227885f7a6605c37fd53d3a106fcab511b8860ecca9fcf400ee", + "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d", + "sha256:cc8955cfbfc7a115fa81d85284ee61147059a753344bc51098f3ccd69b0d7e0c", + "sha256:d13155f591e6fcc1ec3b30685d50bf0711574e2c0dfffd7644babf8b5102ca1a" ], - "version": "==5.2" + "version": "==5.3.1" }, "readme-renderer": { "hashes": [ - "sha256:1b6d8dd1673a0b293766b4106af766b6eff3654605f9c4f239e65de6076bc222", - "sha256:e67d64242f0174a63c3b727801a2fff4c1f38ebe5d71d95ff7ece081945a6cd4" - ], - "version": "==25.0" + "sha256:cbe9db71defedd2428a1589cdc545f9bd98e59297449f69d721ef8f1cfced68d", + "sha256:cc4957a803106e820d05d14f71033092537a22daa4f406dfbdd61177e0936376" + ], + "version": "==26.0" + }, + "regex": { + "hashes": [ + "sha256:08119f707f0ebf2da60d2f24c2f39ca616277bb67ef6c92b72cbf90cbe3a556b", + "sha256:0ce9537396d8f556bcfc317c65b6a0705320701e5ce511f05fc04421ba05b8a8", + "sha256:1cbe0fa0b7f673400eb29e9ef41d4f53638f65f9a2143854de6b1ce2899185c3", + "sha256:2294f8b70e058a2553cd009df003a20802ef75b3c629506be20687df0908177e", + "sha256:23069d9c07e115537f37270d1d5faea3e0bdded8279081c4d4d607a2ad393683", + "sha256:24f4f4062eb16c5bbfff6a22312e8eab92c2c99c51a02e39b4eae54ce8255cd1", + "sha256:295badf61a51add2d428a46b8580309c520d8b26e769868b922750cf3ce67142", + "sha256:2a3bf8b48f8e37c3a40bb3f854bf0121c194e69a650b209628d951190b862de3", + "sha256:4385f12aa289d79419fede43f979e372f527892ac44a541b5446617e4406c468", + "sha256:5635cd1ed0a12b4c42cce18a8d2fb53ff13ff537f09de5fd791e97de27b6400e", + "sha256:5bfed051dbff32fd8945eccca70f5e22b55e4148d2a8a45141a3b053d6455ae3", + "sha256:7e1037073b1b7053ee74c3c6c0ada80f3501ec29d5f46e42669378eae6d4405a", + "sha256:90742c6ff121a9c5b261b9b215cb476eea97df98ea82037ec8ac95d1be7a034f", + "sha256:a58dd45cb865be0ce1d5ecc4cfc85cd8c6867bea66733623e54bd95131f473b6", + 
"sha256:c087bff162158536387c53647411db09b6ee3f9603c334c90943e97b1052a156", + "sha256:c162a21e0da33eb3d31a3ac17a51db5e634fc347f650d271f0305d96601dc15b", + "sha256:c9423a150d3a4fc0f3f2aae897a59919acd293f4cb397429b120a5fcd96ea3db", + "sha256:ccccdd84912875e34c5ad2d06e1989d890d43af6c2242c6fcfa51556997af6cd", + "sha256:e91ba11da11cf770f389e47c3f5c30473e6d85e06d7fd9dcba0017d2867aab4a", + "sha256:ea4adf02d23b437684cd388d557bf76e3afa72f7fed5bbc013482cc00c816948", + "sha256:fb95debbd1a824b2c4376932f2216cc186912e389bdb0e27147778cf6acb3f89" + ], + "version": "==2020.4.4" }, "requests": { "hashes": [ @@ -624,18 +709,18 @@ }, "tox": { "hashes": [ - "sha256:06ba73b149bf838d5cd25dc30c2dd2671ae5b2757cf98e5c41a35fe449f131b3", - "sha256:806d0a9217584558cc93747a945a9d9bff10b141a5287f0c8429a08828a22192" + "sha256:a4a6689045d93c208d77230853b28058b7513f5123647b67bf012f82fa168303", + "sha256:b2c4b91c975ea5c11463d9ca00bebf82654439c5df0f614807b9bdec62cc9471" ], "index": "pypi", - "version": "==3.14.3" + "version": "==3.14.6" }, "tqdm": { "hashes": [ - "sha256:03d2366c64d44c7f61e74c700d9b202d57e9efe355ea5c28814c52bfe7a50b8c", - "sha256:be5ddeec77d78ba781ea41eacb2358a77f74cc2407f54b82222d7ee7dc8c8ccf" + "sha256:00339634a22c10a7a22476ee946bbde2dbe48d042ded784e4d88e0236eca5d81", + "sha256:ea9e3fd6bd9a37e8783d75bfc4c1faf3c6813da6bd1c3e776488b41ec683af94" ], - "version": "==4.44.1" + "version": "==4.45.0" }, "twine": { "hashes": [ @@ -645,20 +730,46 @@ "index": "pypi", "version": "==3.1.1" }, + "typed-ast": { + "hashes": [ + "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355", + "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919", + "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa", + "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652", + "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75", + "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01", + "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d", + "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1", + "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907", + "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c", + "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3", + "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b", + "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614", + "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb", + "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b", + "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41", + "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6", + "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34", + "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe", + "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4", + "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7" + ], + "version": "==1.4.1" + }, "urllib3": { "hashes": [ - "sha256:2f3db8b19923a873b3e5256dc9c2dedfa883e33d87c690d9c7913e1f40673cdc", - "sha256:87716c2d2a7121198ebcb7ce7cccf6ce5e9ba539041cfbaeecfb641dc0bf6acc" + "sha256:3018294ebefce6572a474f0604c2021e33b3fd8006ecd11d62107a5d2a963527", + "sha256:88206b0eb87e6d677d424843ac5209e3fb9d0190d0ee169599165ec25e9d9115" ], "markers": "python_version != '3.4'", - "version": 
"==1.25.8" + "version": "==1.25.9" }, "virtualenv": { "hashes": [ - "sha256:4e399f48c6b71228bf79f5febd27e3bbb753d9d5905776a86667bc61ab628a25", - "sha256:9e81279f4a9d16d1c0654a127c2c86e5bca2073585341691882c1e66e31ef8a5" + "sha256:5021396e8f03d0d002a770da90e31e61159684db2859d0ba4850fbea752aa675", + "sha256:ac53ade75ca189bc97b6c1d9ec0f1a50efe33cbf178ae09452dcd9fd309013c1" ], - "version": "==20.0.15" + "version": "==20.0.18" }, "wcwidth": { "hashes": [ @@ -681,6 +792,14 @@ ], "index": "pypi", "version": "==0.34.2" + }, + "zipp": { + "hashes": [ + "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b", + "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96" + ], + "markers": "python_version < '3.8'", + "version": "==3.1.0" } } } diff --git a/README.rst b/README.rst index 05505238..4e5aa3b6 100644 --- a/README.rst +++ b/README.rst @@ -10,9 +10,11 @@ .. image:: https://img.shields.io/pypi/l/PyAthena.svg :target: https://github.com/laughingman7743/PyAthena/blob/master/LICENSE -.. image:: https://img.shields.io/pypi/dm/PyAthena.svg - :target: https://pypistats.org/packages/pyathena +.. image:: https://pepy.tech/badge/pyathena/month + :target: https://pepy.tech/project/pyathena/month +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black PyAthena ======== @@ -828,3 +830,25 @@ Run test multiple Python versions $ pyenv local 3.8.2 3.7.2 3.6.8 3.5.7 2.7.16 $ pipenv run tox $ pipenv run scripts/test_data/delete_test_data.sh + +Code formatting +--------------- + +The code formatting uses `black`_ and `isort`_. + +Appy format +~~~~~~~~~~~ + +.. code:: bash + + $ make fmt + +Check format +~~~~~~~~~~~~ + +.. code:: bash + + $ make chk + +.. _`black`: https://github.com/psf/black +.. 
_`isort`: https://github.com/timothycrosley/isort diff --git a/setup.cfg b/setup.cfg index ad8a0f28..4948b4c3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,8 @@ universal = 1 [tool:pytest] flake8-max-line-length = 100 + +[isort] +line_length = 100 +order_by_type = True +multi_line_output = 4 diff --git a/tox.ini b/tox.ini index 7770fdb5..86ee7c4b 100644 --- a/tox.ini +++ b/tox.ini @@ -9,8 +9,10 @@ deps = pytest>=3.5 pytest-cov pytest-flake8>=1.0.1 + pytest-black + pytest-isort pytest-xdist commands = - pytest --cov pyathena --cov-report html --cov-report term --flake8 + pytest --cov pyathena --cov-report html --cov-report term --flake8 --black --isort passenv = AWS_* From ee9952b8c52cdafa9a8b571df9717c8e32e545dd Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 16:25:44 +0900 Subject: [PATCH 2/6] Apply fmt --- benchmarks/benchmark.py | 62 +-- pyathena/__init__.py | 32 +- pyathena/async_cursor.py | 60 ++- pyathena/async_pandas_cursor.py | 29 +- pyathena/common.py | 193 ++++++---- pyathena/connection.py | 135 ++++--- pyathena/converter.py | 101 ++--- pyathena/cursor.py | 62 ++- pyathena/error.py | 15 +- pyathena/formatter.py | 51 +-- pyathena/model.py | 109 +++--- pyathena/pandas_cursor.py | 64 ++-- pyathena/result_set.py | 170 +++++---- pyathena/sqlalchemy_athena.py | 204 ++++++---- pyathena/util.py | 313 +++++++++------ setup.py | 75 ++-- tests/__init__.py | 27 +- tests/conftest.py | 80 ++-- tests/test_async_cursor.py | 88 +++-- tests/test_async_pandas_cursor.py | 104 +++-- tests/test_cursor.py | 352 +++++++++-------- tests/test_formatter.py | 176 ++++++--- tests/test_model.py | 50 ++- tests/test_pandas_cursor.py | 390 ++++++++++--------- tests/test_sqlalchemy_athena.py | 409 ++++++++++++-------- tests/test_util.py | 607 +++++++++++++++++++----------- tests/util.py | 17 +- 27 files changed, 2369 insertions(+), 1606 deletions(-) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index 5a808d33..f4cec451 100755 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -5,15 +5,15 @@ import time from pyathena import connect -from pyathenajdbc import connect as jdbc_connect from pyathena.pandas_cursor import PandasCursor +from pyathenajdbc import connect as jdbc_connect LOGGER = logging.getLogger(__name__) LOGGER.addHandler(logging.StreamHandler(sys.stdout)) LOGGER.setLevel(logging.INFO) -S3_STAGING_DIR = 's3://YOUR_BUCKET/path/to/' -REGION_NAME = 'us-west-2' +S3_STAGING_DIR = "s3://YOUR_BUCKET/path/to/" +REGION_NAME = "us-west-2" COUNT = 5 SMALL_RESULT_SET_QUERY = """ @@ -31,44 +31,46 @@ def run_pyathen_pandas_cursor(query): - LOGGER.info('PyAthena PandasCursor =========================') - cursor = connect(s3_staging_dir=S3_STAGING_DIR, - region_name=REGION_NAME, - cursor_class=PandasCursor).cursor() + LOGGER.info("PyAthena PandasCursor =========================") + cursor = connect( + s3_staging_dir=S3_STAGING_DIR, + region_name=REGION_NAME, + cursor_class=PandasCursor, + ).cursor() avgs = [] for i in range(0, COUNT): start = time.time() df = cursor.execute(query).as_pandas() end = time.time() elapsed = end - start - LOGGER.info('loop:{0}\tcount:{1}\telasped:{2}'.format(i, df.shape[0], elapsed)) + LOGGER.info("loop:{0}\tcount:{1}\telasped:{2}".format(i, df.shape[0], elapsed)) avgs.append(elapsed) avg = sum(avgs) / COUNT - LOGGER.info('Avg: {0}'.format(avg)) - LOGGER.info('===============================================') + LOGGER.info("Avg: {0}".format(avg)) + LOGGER.info("===============================================") def 
run_pyathena_cursor(query): - LOGGER.info('PyAthena Cursor ===============================') - cursor = connect(s3_staging_dir=S3_STAGING_DIR, - region_name=REGION_NAME).cursor() + LOGGER.info("PyAthena Cursor ===============================") + cursor = connect(s3_staging_dir=S3_STAGING_DIR, region_name=REGION_NAME).cursor() avgs = [] for i in range(0, COUNT): start = time.time() result = cursor.execute(query).fetchall() end = time.time() elapsed = end - start - LOGGER.info('loop:{0}\tcount:{1}\telasped:{2}'.format(i, len(result), elapsed)) + LOGGER.info("loop:{0}\tcount:{1}\telasped:{2}".format(i, len(result), elapsed)) avgs.append(elapsed) avg = sum(avgs) / COUNT - LOGGER.info('Avg: {0}'.format(avg)) - LOGGER.info('===============================================') + LOGGER.info("Avg: {0}".format(avg)) + LOGGER.info("===============================================") def run_pyathenajdbc_cursor(query): - LOGGER.info('PyAthenaJDBC Cursor ===========================') - cursor = jdbc_connect(s3_staging_dir=S3_STAGING_DIR, - region_name=REGION_NAME).cursor() + LOGGER.info("PyAthenaJDBC Cursor ===========================") + cursor = jdbc_connect( + s3_staging_dir=S3_STAGING_DIR, region_name=REGION_NAME + ).cursor() avgs = [] for i in range(0, COUNT): start = time.time() @@ -76,25 +78,27 @@ def run_pyathenajdbc_cursor(query): result = cursor.fetchall() end = time.time() elapsed = end - start - LOGGER.info('loop:{0}\tcount:{1}\telasped:{2}'.format(i, len(result), elapsed)) + LOGGER.info("loop:{0}\tcount:{1}\telasped:{2}".format(i, len(result), elapsed)) avgs.append(elapsed) avg = sum(avgs) / COUNT - LOGGER.info('Avg: {0}'.format(avg)) - LOGGER.info('===============================================') + LOGGER.info("Avg: {0}".format(avg)) + LOGGER.info("===============================================") def main(): - for query in [SMALL_RESULT_SET_QUERY, - MEDIUM_RESULT_SET_QUERY, - LARGE_RESULT_SET_QUERY]: + for query in [ + SMALL_RESULT_SET_QUERY, + MEDIUM_RESULT_SET_QUERY, + LARGE_RESULT_SET_QUERY, + ]: LOGGER.info(query) run_pyathenajdbc_cursor(query) - LOGGER.info('') + LOGGER.info("") run_pyathena_cursor(query) - LOGGER.info('') + LOGGER.info("") run_pyathen_pandas_cursor(query) - LOGGER.info('') + LOGGER.info("") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyathena/__init__.py b/pyathena/__init__.py index ee1b6cb8..09dbc392 100644 --- a/pyathena/__init__.py +++ b/pyathena/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import datetime @@ -9,15 +8,17 @@ try: from multiprocessing import cpu_count except ImportError: + def cpu_count(): return None -__version__ = '1.10.5' + +__version__ = "1.10.5" # Globals https://www.python.org/dev/peps/pep-0249/#globals -apilevel = '2.0' +apilevel = "2.0" threadsafety = 3 -paramstyle = 'pyformat' +paramstyle = "pyformat" class DBAPITypeObject: @@ -25,6 +26,7 @@ class DBAPITypeObject: https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors """ + def __init__(self, *values): self.values = values @@ -41,15 +43,16 @@ def __eq__(self, other): # https://docs.aws.amazon.com/athena/latest/ug/data-types.html -STRING = DBAPITypeObject('char', 'varchar', 'map', 'array', 'row') -BINARY = DBAPITypeObject('varbinary') -BOOLEAN = DBAPITypeObject('boolean') -NUMBER = DBAPITypeObject('tinyint', 'smallint', 'bigint', 'integer', - 'real', 'double', 'float', 'decimal') -DATE = 
DBAPITypeObject('date') -TIME = DBAPITypeObject('time', 'time with time zone') -DATETIME = DBAPITypeObject('timestamp', 'timestamp with time zone') -JSON = DBAPITypeObject('json') +STRING = DBAPITypeObject("char", "varchar", "map", "array", "row") +BINARY = DBAPITypeObject("varbinary") +BOOLEAN = DBAPITypeObject("boolean") +NUMBER = DBAPITypeObject( + "tinyint", "smallint", "bigint", "integer", "real", "double", "float", "decimal" +) +DATE = DBAPITypeObject("date") +TIME = DBAPITypeObject("time", "time with time zone") +DATETIME = DBAPITypeObject("timestamp", "timestamp with time zone") +JSON = DBAPITypeObject("json") Date = datetime.date Time = datetime.time @@ -58,4 +61,5 @@ def __eq__(self, other): def connect(*args, **kwargs): from pyathena.connection import Connection + return Connection(*args, **kwargs) diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index f3a2a934..81624f2c 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging from concurrent.futures.thread import ThreadPoolExecutor @@ -15,11 +14,21 @@ class AsyncCursor(BaseCursor): - - def __init__(self, connection, s3_staging_dir, schema_name, work_group, - poll_interval, encryption_option, kms_key, converter, formatter, - retry_config, max_workers=(cpu_count() or 1) * 5, - arraysize=CursorIterator.DEFAULT_FETCH_SIZE): + def __init__( + self, + connection, + s3_staging_dir, + schema_name, + work_group, + poll_interval, + encryption_option, + kms_key, + converter, + formatter, + retry_config, + max_workers=(cpu_count() or 1) * 5, + arraysize=CursorIterator.DEFAULT_FETCH_SIZE, + ): super(AsyncCursor, self).__init__( connection=connection, s3_staging_dir=s3_staging_dir, @@ -30,7 +39,8 @@ def __init__(self, connection, s3_staging_dir, schema_name, work_group, kms_key=kms_key, converter=converter, formatter=formatter, - retry_config=retry_config) + retry_config=retry_config, + ) self._executor = ThreadPoolExecutor(max_workers=max_workers) self._arraysize = arraysize @@ -41,8 +51,11 @@ def arraysize(self): @arraysize.setter def arraysize(self, value): if value <= 0 or value > CursorIterator.DEFAULT_FETCH_SIZE: - raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format( - CursorIterator.DEFAULT_FETCH_SIZE)) + raise ProgrammingError( + "MaxResults is more than maximum allowed length {0}.".format( + CursorIterator.DEFAULT_FETCH_SIZE + ) + ) self._arraysize = value def close(self, wait=False): @@ -68,15 +81,24 @@ def _collect_result_set(self, query_id): converter=self._converter, query_execution=query_execution, arraysize=self._arraysize, - retry_config=self._retry_config) - - def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=None, - cache_size=0): - query_id = self._execute(operation, - parameters=parameters, - work_group=work_group, - s3_staging_dir=s3_staging_dir, - cache_size=cache_size) + retry_config=self._retry_config, + ) + + def execute( + self, + operation, + parameters=None, + work_group=None, + s3_staging_dir=None, + cache_size=0, + ): + query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + ) return query_id, self._executor.submit(self._collect_result_set, query_id) def executemany(self, operation, seq_of_parameters): diff --git 
a/pyathena/async_pandas_cursor.py b/pyathena/async_pandas_cursor.py index c5f86683..21e318c0 100644 --- a/pyathena/async_pandas_cursor.py +++ b/pyathena/async_pandas_cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging @@ -13,11 +12,21 @@ class AsyncPandasCursor(AsyncCursor): - - def __init__(self, connection, s3_staging_dir, schema_name, work_group, - poll_interval, encryption_option, kms_key, converter, formatter, - retry_config, max_workers=(cpu_count() or 1) * 5, - arraysize=CursorIterator.DEFAULT_FETCH_SIZE): + def __init__( + self, + connection, + s3_staging_dir, + schema_name, + work_group, + poll_interval, + encryption_option, + kms_key, + converter, + formatter, + retry_config, + max_workers=(cpu_count() or 1) * 5, + arraysize=CursorIterator.DEFAULT_FETCH_SIZE, + ): super(AsyncPandasCursor, self).__init__( connection=connection, s3_staging_dir=s3_staging_dir, @@ -30,7 +39,8 @@ def __init__(self, connection, s3_staging_dir, schema_name, work_group, formatter=formatter, retry_config=retry_config, max_workers=max_workers, - arraysize=arraysize) + arraysize=arraysize, + ) def _collect_result_set(self, query_id): query_execution = self._poll(query_id) @@ -39,4 +49,5 @@ def _collect_result_set(self, query_id): converter=self._converter, query_execution=query_execution, arraysize=self._arraysize, - retry_config=self._retry_config) + retry_config=self._retry_config, + ) diff --git a/pyathena/common.py b/pyathena/common.py index c1000f74..944b2b06 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging import time @@ -21,7 +20,7 @@ class CursorIterator(with_metaclass(ABCMeta, object)): def __init__(self, **kwargs): super(CursorIterator, self).__init__() - self.arraysize = kwargs.get('arraysize', self.DEFAULT_FETCH_SIZE) + self.arraysize = kwargs.get("arraysize", self.DEFAULT_FETCH_SIZE) self._rownumber = None @property @@ -31,8 +30,11 @@ def arraysize(self): @arraysize.setter def arraysize(self, value): if value <= 0 or value > self.DEFAULT_FETCH_SIZE: - raise ProgrammingError('MaxResults is more than maximum allowed length {0}.'.format( - self.DEFAULT_FETCH_SIZE)) + raise ProgrammingError( + "MaxResults is more than maximum allowed length {0}.".format( + self.DEFAULT_FETCH_SIZE + ) + ) self._arraysize = value @property @@ -70,10 +72,20 @@ def __iter__(self): class BaseCursor(with_metaclass(ABCMeta, object)): - - def __init__(self, connection, s3_staging_dir, schema_name, work_group, - poll_interval, encryption_option, kms_key, converter, formatter, - retry_config, **kwargs): + def __init__( + self, + connection, + s3_staging_dir, + schema_name, + work_group, + poll_interval, + encryption_option, + kms_key, + converter, + formatter, + retry_config, + **kwargs + ): super(BaseCursor, self).__init__(**kwargs) self._connection = connection self._s3_staging_dir = s3_staging_dir @@ -91,14 +103,16 @@ def connection(self): return self._connection def _get_query_execution(self, query_id): - request = {'QueryExecutionId': query_id} + request = {"QueryExecutionId": query_id} try: - response = retry_api_call(self._connection.client.get_query_execution, - config=self._retry_config, - logger=_logger, - **request) + response = retry_api_call( + 
self._connection.client.get_query_execution, + config=self._retry_config, + logger=_logger, + **request + ) except Exception as e: - _logger.exception('Failed to get query execution.') + _logger.exception("Failed to get query execution.") raise_from(OperationalError(*e.args), e) else: return AthenaQueryExecution(response) @@ -106,53 +120,56 @@ def _get_query_execution(self, query_id): def _poll(self, query_id): while True: query_execution = self._get_query_execution(query_id) - if query_execution.state in [AthenaQueryExecution.STATE_SUCCEEDED, - AthenaQueryExecution.STATE_FAILED, - AthenaQueryExecution.STATE_CANCELLED]: + if query_execution.state in [ + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED, + ]: return query_execution else: time.sleep(self._poll_interval) - def _build_start_query_execution_request(self, query, work_group=None, s3_staging_dir=None): + def _build_start_query_execution_request( + self, query, work_group=None, s3_staging_dir=None + ): request = { - 'QueryString': query, - 'QueryExecutionContext': { - 'Database': self._schema_name, - }, - 'ResultConfiguration': {} + "QueryString": query, + "QueryExecutionContext": {"Database": self._schema_name,}, + "ResultConfiguration": {}, } if self._s3_staging_dir or s3_staging_dir: - request['ResultConfiguration'].update({ - 'OutputLocation': s3_staging_dir if s3_staging_dir else self._s3_staging_dir - }) + request["ResultConfiguration"].update( + { + "OutputLocation": s3_staging_dir + if s3_staging_dir + else self._s3_staging_dir + } + ) if self._work_group or work_group: - request.update({ - 'WorkGroup': work_group if work_group else self._work_group - }) + request.update( + {"WorkGroup": work_group if work_group else self._work_group} + ) if self._encryption_option: enc_conf = { - 'EncryptionOption': self._encryption_option, + "EncryptionOption": self._encryption_option, } if self._kms_key: - enc_conf.update({ - 'KmsKey': self._kms_key - }) - request['ResultConfiguration'].update({ - 'EncryptionConfiguration': enc_conf, - }) + enc_conf.update({"KmsKey": self._kms_key}) + request["ResultConfiguration"].update( + {"EncryptionConfiguration": enc_conf,} + ) return request - def _build_list_query_executions_request(self, max_results, work_group, - next_token=None): - request = {'MaxResults': max_results} + def _build_list_query_executions_request( + self, max_results, work_group, next_token=None + ): + request = {"MaxResults": max_results} if self._work_group or work_group: - request.update({ - 'WorkGroup': work_group if work_group else self._work_group - }) + request.update( + {"WorkGroup": work_group if work_group else self._work_group} + ) if next_token: - request.update({ - 'NextToken': next_token - }) + request.update({"NextToken": next_token}) return request def _find_previous_query_id(self, query, work_group, cache_size): @@ -162,56 +179,78 @@ def _find_previous_query_id(self, query, work_group, cache_size): while cache_size > 0: n = min(cache_size, 50) # 50 is max allowed by AWS API cache_size -= n - request = self._build_list_query_executions_request(n, work_group, next_token) - response = retry_api_call(self.connection._client.list_query_executions, - config=self._retry_config, - logger=_logger, - **request) - query_ids = response.get('QueryExecutionIds', None) + request = self._build_list_query_executions_request( + n, work_group, next_token + ) + response = retry_api_call( + self.connection._client.list_query_executions, + config=self._retry_config, + 
logger=_logger, + **request + ) + query_ids = response.get("QueryExecutionIds", None) if not query_ids: break # no queries left to check - next_token = response.get('NextToken', None) + next_token = response.get("NextToken", None) query_executions = retry_api_call( self.connection._client.batch_get_query_execution, config=self._retry_config, logger=_logger, - QueryExecutionIds=query_ids - ).get('QueryExecutions', []) + QueryExecutionIds=query_ids, + ).get("QueryExecutions", []) for execution in query_executions: if ( - execution['Query'] == query and - execution['Status']['State'] == AthenaQueryExecution.STATE_SUCCEEDED and - execution['StatementType'] == AthenaQueryExecution.STATEMENT_TYPE_DML + execution["Query"] == query + and execution["Status"]["State"] + == AthenaQueryExecution.STATE_SUCCEEDED + and execution["StatementType"] + == AthenaQueryExecution.STATEMENT_TYPE_DML ): - query_id = execution['QueryExecutionId'] + query_id = execution["QueryExecutionId"] break if query_id or next_token is None: break except Exception: - _logger.warning('Failed to check the cache. Moving on without cache.') + _logger.warning("Failed to check the cache. Moving on without cache.") return query_id - def _execute(self, operation, parameters=None, work_group=None, s3_staging_dir=None, - cache_size=0): + def _execute( + self, + operation, + parameters=None, + work_group=None, + s3_staging_dir=None, + cache_size=0, + ): query = self._formatter.format(operation, parameters) _logger.debug(query) - request = self._build_start_query_execution_request(query, work_group, s3_staging_dir) + request = self._build_start_query_execution_request( + query, work_group, s3_staging_dir + ) query_id = self._find_previous_query_id(query, work_group, cache_size) if query_id is None: try: - query_id = retry_api_call(self._connection.client.start_query_execution, - config=self._retry_config, - logger=_logger, - **request).get('QueryExecutionId', None) + query_id = retry_api_call( + self._connection.client.start_query_execution, + config=self._retry_config, + logger=_logger, + **request + ).get("QueryExecutionId", None) except Exception as e: - _logger.exception('Failed to execute query.') + _logger.exception("Failed to execute query.") raise_from(DatabaseError(*e.args), e) return query_id @abstractmethod - def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=None, - cache_size=0): + def execute( + self, + operation, + parameters=None, + work_group=None, + s3_staging_dir=None, + cache_size=0, + ): raise NotImplementedError # pragma: no cover @abstractmethod @@ -223,14 +262,16 @@ def close(self): raise NotImplementedError # pragma: no cover def _cancel(self, query_id): - request = {'QueryExecutionId': query_id} + request = {"QueryExecutionId": query_id} try: - retry_api_call(self._connection.client.stop_query_execution, - config=self._retry_config, - logger=_logger, - **request) + retry_api_call( + self._connection.client.stop_query_execution, + config=self._retry_config, + logger=_logger, + **request + ) except Exception as e: - _logger.exception('Failed to cancel query.') + _logger.exception("Failed to cancel query.") raise_from(OperationalError(*e.args), e) def setinputsizes(self, sizes): diff --git a/pyathena/connection.py b/pyathena/connection.py index ffec6a88..8ba49695 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, 
unicode_literals import logging import os @@ -10,7 +9,7 @@ from future.utils import iteritems from pyathena.async_pandas_cursor import AsyncPandasCursor -from pyathena.converter import DefaultTypeConverter, DefaultPandasTypeConverter +from pyathena.converter import DefaultPandasTypeConverter, DefaultTypeConverter from pyathena.cursor import Cursor from pyathena.error import NotSupportedError from pyathena.formatter import DefaultParameterFormatter @@ -22,25 +21,46 @@ class Connection(object): - _ENV_S3_STAGING_DIR = 'AWS_ATHENA_S3_STAGING_DIR' - _ENV_WORK_GROUP = 'AWS_ATHENA_WORK_GROUP' + _ENV_S3_STAGING_DIR = "AWS_ATHENA_S3_STAGING_DIR" + _ENV_WORK_GROUP = "AWS_ATHENA_WORK_GROUP" _SESSION_PASSING_ARGS = [ - 'aws_access_key_id', 'aws_secret_access_key', - 'aws_session_token', 'region_name', - 'botocore_session', 'profile_name', + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "region_name", + "botocore_session", + "profile_name", ] _CLIENT_PASSING_ARGS = [ - 'aws_access_key_id', 'aws_secret_access_key', - 'aws_session_token', 'config', - 'api_version', 'use_ssl', 'verify', 'endpoint_url', + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "config", + "api_version", + "use_ssl", + "verify", + "endpoint_url", ] - def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', - work_group=None, poll_interval=1, encryption_option=None, kms_key=None, - profile_name=None, role_arn=None, - role_session_name='PyAthena-session-{0}'.format(int(time.time())), - duration_seconds=3600, converter=None, formatter=None, - retry_config=None, cursor_class=Cursor, **kwargs): + def __init__( + self, + s3_staging_dir=None, + region_name=None, + schema_name="default", + work_group=None, + poll_interval=1, + encryption_option=None, + kms_key=None, + profile_name=None, + role_arn=None, + role_session_name="PyAthena-session-{0}".format(int(time.time())), + duration_seconds=3600, + converter=None, + formatter=None, + retry_config=None, + cursor_class=Cursor, + **kwargs + ): self._kwargs = kwargs if s3_staging_dir: self.s3_staging_dir = s3_staging_dir @@ -57,55 +77,60 @@ def __init__(self, s3_staging_dir=None, region_name=None, schema_name='default', self.kms_key = kms_key self.profile_name = profile_name - assert self.schema_name, 'Required argument `schema_name` not found.' - assert self.s3_staging_dir or self.work_group,\ - 'Required argument `s3_staging_dir` or `work_group` not found.' + assert self.schema_name, "Required argument `schema_name` not found." + assert ( + self.s3_staging_dir or self.work_group + ), "Required argument `s3_staging_dir` or `work_group` not found." 
if role_arn: - creds = self._assume_role(self.profile_name, self.region_name, role_arn, - role_session_name, duration_seconds) + creds = self._assume_role( + self.profile_name, + self.region_name, + role_arn, + role_session_name, + duration_seconds, + ) self.profile_name = None - self._kwargs.update({ - 'aws_access_key_id': creds['AccessKeyId'], - 'aws_secret_access_key': creds['SecretAccessKey'], - 'aws_session_token': creds['SessionToken'], - }) - self._session = Session(profile_name=self.profile_name, - **self._session_kwargs) - self._client = self._session.client('athena', region_name=self.region_name, - **self._client_kwargs) + self._kwargs.update( + { + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds["SecretAccessKey"], + "aws_session_token": creds["SessionToken"], + } + ) + self._session = Session(profile_name=self.profile_name, **self._session_kwargs) + self._client = self._session.client( + "athena", region_name=self.region_name, **self._client_kwargs + ) self._converter = converter self._formatter = formatter if formatter else DefaultParameterFormatter() self._retry_config = retry_config if retry_config else RetryConfig() self.cursor_class = cursor_class - def _assume_role(self, profile_name, region_name, role_arn, - role_session_name, duration_seconds): + def _assume_role( + self, profile_name, region_name, role_arn, role_session_name, duration_seconds + ): # MFA is not supported. If you want to use MFA, create a configuration file. # http://boto3.readthedocs.io/en/latest/guide/configuration.html#assume-role-provider - session = Session(profile_name=profile_name, - **self._session_kwargs) - client = session.client('sts', region_name=region_name, - **self._client_kwargs) + session = Session(profile_name=profile_name, **self._session_kwargs) + client = session.client("sts", region_name=region_name, **self._client_kwargs) response = client.assume_role( RoleArn=role_arn, RoleSessionName=role_session_name, DurationSeconds=duration_seconds, ) - return response['Credentials'] + return response["Credentials"] @property def _session_kwargs(self): return { - k: v for k, v in iteritems(self._kwargs) - if k in self._SESSION_PASSING_ARGS + k: v for k, v in iteritems(self._kwargs) if k in self._SESSION_PASSING_ARGS } @property def _client_kwargs(self): return { - k: v for k, v in iteritems(self._kwargs) - if k in self._CLIENT_PASSING_ARGS + k: v for k, v in iteritems(self._kwargs) if k in self._CLIENT_PASSING_ARGS } @property @@ -129,23 +154,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): def cursor(self, cursor=None, **kwargs): if not cursor: cursor = self.cursor_class - converter = kwargs.pop('converter', self._converter) + converter = kwargs.pop("converter", self._converter) if not converter: if cursor is PandasCursor or cursor is AsyncPandasCursor: converter = DefaultPandasTypeConverter() else: converter = DefaultTypeConverter() - return cursor(connection=self, - s3_staging_dir=kwargs.pop('s3_staging_dir', self.s3_staging_dir), - schema_name=kwargs.pop('schema_name', self.schema_name), - work_group=kwargs.pop('work_group', self.work_group), - poll_interval=kwargs.pop('poll_interval', self.poll_interval), - encryption_option=kwargs.pop('encryption_option', self.encryption_option), - kms_key=kwargs.pop('kms_key', self.kms_key), - converter=converter, - formatter=kwargs.pop('formatter', self._formatter), - retry_config=kwargs.pop('retry_config', self._retry_config), - **kwargs) + return cursor( + connection=self, + s3_staging_dir=kwargs.pop("s3_staging_dir", 
self.s3_staging_dir), + schema_name=kwargs.pop("schema_name", self.schema_name), + work_group=kwargs.pop("work_group", self.work_group), + poll_interval=kwargs.pop("poll_interval", self.poll_interval), + encryption_option=kwargs.pop("encryption_option", self.encryption_option), + kms_key=kwargs.pop("kms_key", self.kms_key), + converter=converter, + formatter=kwargs.pop("formatter", self._formatter), + retry_config=kwargs.pop("retry_config", self._retry_config), + **kwargs + ) def close(self): pass diff --git a/pyathena/converter.py b/pyathena/converter.py index 079697fd..5973ba1b 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import binascii import json @@ -19,19 +18,19 @@ def _to_date(varchar_value): if varchar_value is None: return None - return datetime.strptime(varchar_value, '%Y-%m-%d').date() + return datetime.strptime(varchar_value, "%Y-%m-%d").date() def _to_datetime(varchar_value): if varchar_value is None: return None - return datetime.strptime(varchar_value, '%Y-%m-%d %H:%M:%S.%f') + return datetime.strptime(varchar_value, "%Y-%m-%d %H:%M:%S.%f") def _to_time(varchar_value): if varchar_value is None: return None - return datetime.strptime(varchar_value, '%H:%M:%S.%f').time() + return datetime.strptime(varchar_value, "%H:%M:%S.%f").time() def _to_float(varchar_value): @@ -53,7 +52,7 @@ def _to_decimal(varchar_value): def _to_boolean(varchar_value): - if varchar_value is None or varchar_value == '': + if varchar_value is None or varchar_value == "": return None return bool(strtobool(varchar_value)) @@ -61,7 +60,7 @@ def _to_boolean(varchar_value): def _to_binary(varchar_value): if varchar_value is None: return None - return binascii.a2b_hex(''.join(varchar_value.split(' '))) + return binascii.a2b_hex("".join(varchar_value.split(" "))) def _to_json(varchar_value): @@ -78,37 +77,36 @@ def _to_default(varchar_value): _DEFAULT_CONVERTERS = { - 'boolean': _to_boolean, - 'tinyint': _to_int, - 'smallint': _to_int, - 'integer': _to_int, - 'bigint': _to_int, - 'float': _to_float, - 'real': _to_float, - 'double': _to_float, - 'char': _to_default, - 'varchar': _to_default, - 'string': _to_default, - 'timestamp': _to_datetime, - 'date': _to_date, - 'time': _to_time, - 'varbinary': _to_binary, - 'array': _to_default, - 'map': _to_default, - 'row': _to_default, - 'decimal': _to_decimal, - 'json': _to_json, + "boolean": _to_boolean, + "tinyint": _to_int, + "smallint": _to_int, + "integer": _to_int, + "bigint": _to_int, + "float": _to_float, + "real": _to_float, + "double": _to_float, + "char": _to_default, + "varchar": _to_default, + "string": _to_default, + "timestamp": _to_datetime, + "date": _to_date, + "time": _to_time, + "varbinary": _to_binary, + "array": _to_default, + "map": _to_default, + "row": _to_default, + "decimal": _to_decimal, + "json": _to_json, } _DEFAULT_PANDAS_CONVERTERS = { - 'boolean': _to_boolean, - 'decimal': _to_decimal, - 'varbinary': _to_binary, - 'json': _to_json, + "boolean": _to_boolean, + "decimal": _to_decimal, + "varbinary": _to_binary, + "json": _to_json, } class Converter(with_metaclass(ABCMeta, object)): - def __init__(self, mappings, default=None, types=None): if mappings: self._mappings = mappings @@ -146,10 +144,10 @@ def convert(self, type_, value): class DefaultTypeConverter(Converter): - def __init__(self): super(DefaultTypeConverter, self).__init__( - 
mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default) + mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default + ) def convert(self, type_, value): converter = self.get(type_) @@ -157,29 +155,32 @@ def convert(self, type_, value): class DefaultPandasTypeConverter(Converter): - def __init__(self): super(DefaultPandasTypeConverter, self).__init__( - mappings=deepcopy(_DEFAULT_PANDAS_CONVERTERS), default=_to_default, types=self._dtypes) + mappings=deepcopy(_DEFAULT_PANDAS_CONVERTERS), + default=_to_default, + types=self._dtypes, + ) @property def _dtypes(self): - if not hasattr(self, '__dtypes'): + if not hasattr(self, "__dtypes"): import pandas as pd + self.__dtypes = { - 'tinyint': pd.Int64Dtype(), - 'smallint': pd.Int64Dtype(), - 'integer': pd.Int64Dtype(), - 'bigint': pd.Int64Dtype(), - 'float': float, - 'real': float, - 'double': float, - 'char': str, - 'varchar': str, - 'string': str, - 'array': str, - 'map': str, - 'row': str, + "tinyint": pd.Int64Dtype(), + "smallint": pd.Int64Dtype(), + "integer": pd.Int64Dtype(), + "bigint": pd.Int64Dtype(), + "float": float, + "real": float, + "double": float, + "char": str, + "varchar": str, + "string": str, + "array": str, + "map": str, + "row": str, } return self.__dtypes diff --git a/pyathena/cursor.py b/pyathena/cursor.py index f9c7c0dd..6eafa61f 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging @@ -14,10 +13,20 @@ class Cursor(BaseCursor, CursorIterator, WithResultSet): - - def __init__(self, connection, s3_staging_dir, schema_name, work_group, - poll_interval, encryption_option, kms_key, converter, formatter, - retry_config, **kwargs): + def __init__( + self, + connection, + s3_staging_dir, + schema_name, + work_group, + poll_interval, + encryption_option, + kms_key, + converter, + formatter, + retry_config, + **kwargs + ): super(Cursor, self).__init__( connection=connection, s3_staging_dir=s3_staging_dir, @@ -29,7 +38,8 @@ def __init__(self, connection, s3_staging_dir, schema_name, work_group, converter=converter, formatter=formatter, retry_config=retry_config, - **kwargs) + **kwargs + ) @property def rownumber(self): @@ -40,19 +50,31 @@ def close(self): self._result_set.close() @synchronized - def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=None, - cache_size=0): + def execute( + self, + operation, + parameters=None, + work_group=None, + s3_staging_dir=None, + cache_size=0, + ): self._reset_state() - self._query_id = self._execute(operation, - parameters=parameters, - work_group=work_group, - s3_staging_dir=s3_staging_dir, - cache_size=cache_size) + self._query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + ) query_execution = self._poll(self._query_id) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: self._result_set = AthenaResultSet( - self._connection, self._converter, query_execution, self.arraysize, - self._retry_config) + self._connection, + self._converter, + query_execution, + self.arraysize, + self._retry_config, + ) else: raise OperationalError(query_execution.state_change_reason) return self @@ -66,23 +88,23 @@ def executemany(self, operation, seq_of_parameters): @synchronized def cancel(self): if not self._query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') 
+ raise ProgrammingError("QueryExecutionId is none or empty.") self._cancel(self._query_id) @synchronized def fetchone(self): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchone() @synchronized def fetchmany(self, size=None): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchmany(size) @synchronized def fetchall(self): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchall() diff --git a/pyathena/error.py b/pyathena/error.py index de76a4d7..18020c7d 100644 --- a/pyathena/error.py +++ b/pyathena/error.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals - +from __future__ import absolute_import, unicode_literals __all__ = [ - 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', - 'OperationalError', 'ProgrammingError', 'DataError', 'NotSupportedError', + "Error", + "Warning", + "InterfaceError", + "DatabaseError", + "InternalError", + "OperationalError", + "ProgrammingError", + "DataError", + "NotSupportedError", ] diff --git a/pyathena/formatter.py b/pyathena/formatter.py index 6f342ce2..193cfd2b 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging from abc import ABCMeta, abstractmethod @@ -27,16 +26,17 @@ def _escape_hive(val): """HiveParamEscaper https://github.com/dropbox/PyHive/blob/master/pyhive/hive.py""" - return "'{0}'".format(val - .replace('\\', '\\\\') - .replace("'", "\\'") - .replace('\r', '\\r') - .replace('\n', '\\n') - .replace('\t', '\\t')) + return "'{0}'".format( + val.replace("\\", "\\\\") + .replace("'", "\\'") + .replace("\r", "\\r") + .replace("\n", "\\n") + .replace("\t", "\\t") + ) def _format_none(formatter, escaper, val): - return 'null' + return "null" def _format_default(formatter, escaper, val): @@ -48,7 +48,7 @@ def _format_date(formatter, escaper, val): def _format_datetime(formatter, escaper, val): - return "TIMESTAMP '{0}'".format(val.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]) + return "TIMESTAMP '{0}'".format(val.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]) def _format_bool(formatter, escaper, val): @@ -64,18 +64,18 @@ def _format_seq(formatter, escaper, val): for v in val: func = formatter.get(v) formatted = func(formatter, escaper, v) - if not isinstance(formatted, (str, unicode, )): + if not isinstance(formatted, (str, unicode,)): # force string format - if isinstance(formatted, (float, Decimal, )): - formatted = '{0:f}'.format(formatted) + if isinstance(formatted, (float, Decimal,)): + formatted = "{0:f}".format(formatted) else: - formatted = '{0}'.format(formatted) + formatted = "{0}".format(formatted) results.append(formatted) - return '({0})'.format(', '.join(results)) + return "({0})".format(", ".join(results)) def _format_decimal(formatter, escaper, val): - return "DECIMAL {0}".format(escaper('{0:f}'.format(val))) + return "DECIMAL {0}".format(escaper("{0:f}".format(val))) _DEFAULT_FORMATTERS = { @@ -96,7 +96,6 @@ def _format_decimal(formatter, escaper, val): class Formatter(with_metaclass(ABCMeta, object)): - def __init__(self, mappings, default=None): self._mappings = mappings 
self._default = default @@ -123,17 +122,19 @@ def format(self, operation, parameters=None): class DefaultParameterFormatter(Formatter): - def __init__(self): super(DefaultParameterFormatter, self).__init__( - mappings=deepcopy(_DEFAULT_FORMATTERS), default=None) + mappings=deepcopy(_DEFAULT_FORMATTERS), default=None + ) def format(self, operation, parameters=None): if not operation or not operation.strip(): - raise ProgrammingError('Query is none or empty.') + raise ProgrammingError("Query is none or empty.") operation = operation.strip() - if operation.upper().startswith('SELECT') or operation.upper().startswith('WITH'): + if operation.upper().startswith("SELECT") or operation.upper().startswith( + "WITH" + ): escaper = _escape_presto else: escaper = _escape_hive @@ -144,10 +145,12 @@ def format(self, operation, parameters=None): for k, v in iteritems(parameters): func = self.get(v) if not func: - raise TypeError('{0} is not defined formatter.'.format(type(v))) + raise TypeError("{0} is not defined formatter.".format(type(v))) kwargs.update({k: func(self, escaper, v)}) else: - raise ProgrammingError('Unsupported parameter ' + - '(Support for dict only): {0}'.format(parameters)) + raise ProgrammingError( + "Unsupported parameter " + + "(Support for dict only): {0}".format(parameters) + ) return (operation % kwargs).strip() if kwargs else operation.strip() diff --git a/pyathena/model.py b/pyathena/model.py index baf0f6ee..6c202af8 100644 --- a/pyathena/model.py +++ b/pyathena/model.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging @@ -11,53 +10,55 @@ class AthenaQueryExecution(object): - STATE_QUEUED = 'QUEUED' - STATE_RUNNING = 'RUNNING' - STATE_SUCCEEDED = 'SUCCEEDED' - STATE_FAILED = 'FAILED' - STATE_CANCELLED = 'CANCELLED' + STATE_QUEUED = "QUEUED" + STATE_RUNNING = "RUNNING" + STATE_SUCCEEDED = "SUCCEEDED" + STATE_FAILED = "FAILED" + STATE_CANCELLED = "CANCELLED" - STATEMENT_TYPE_DDL = 'DDL' - STATEMENT_TYPE_DML = 'DML' - STATEMENT_TYPE_UTILITY = 'UTILITY' + STATEMENT_TYPE_DDL = "DDL" + STATEMENT_TYPE_DML = "DML" + STATEMENT_TYPE_UTILITY = "UTILITY" def __init__(self, response): - query_execution = response.get('QueryExecution', None) + query_execution = response.get("QueryExecution", None) if not query_execution: - raise DataError('KeyError `QueryExecution`') + raise DataError("KeyError `QueryExecution`") - query_execution_context = query_execution.get('QueryExecutionContext', {}) - self._database = query_execution_context.get('Database', None) + query_execution_context = query_execution.get("QueryExecutionContext", {}) + self._database = query_execution_context.get("Database", None) - self._query_id = query_execution.get('QueryExecutionId', None) + self._query_id = query_execution.get("QueryExecutionId", None) if not self._query_id: - raise DataError('KeyError `QueryExecutionId`') + raise DataError("KeyError `QueryExecutionId`") - self._query = query_execution.get('Query', None) + self._query = query_execution.get("Query", None) if not self._query: - raise DataError('KeyError `Query`') - self._statement_type = query_execution.get('StatementType', None) + raise DataError("KeyError `Query`") + self._statement_type = query_execution.get("StatementType", None) - status = query_execution.get('Status', None) + status = query_execution.get("Status", None) if not status: - raise DataError('KeyError `Status`') - self._state = status.get('State', 
None) - self._state_change_reason = status.get('StateChangeReason', None) - self._completion_date_time = status.get('CompletionDateTime', None) - self._submission_date_time = status.get('SubmissionDateTime', None) + raise DataError("KeyError `Status`") + self._state = status.get("State", None) + self._state_change_reason = status.get("StateChangeReason", None) + self._completion_date_time = status.get("CompletionDateTime", None) + self._submission_date_time = status.get("SubmissionDateTime", None) - statistics = query_execution.get('Statistics', {}) - self._data_scanned_in_bytes = statistics.get('DataScannedInBytes', None) - self._execution_time_in_millis = statistics.get('EngineExecutionTimeInMillis', None) + statistics = query_execution.get("Statistics", {}) + self._data_scanned_in_bytes = statistics.get("DataScannedInBytes", None) + self._execution_time_in_millis = statistics.get( + "EngineExecutionTimeInMillis", None + ) - result_conf = query_execution.get('ResultConfiguration', {}) - self._output_location = result_conf.get('OutputLocation', None) + result_conf = query_execution.get("ResultConfiguration", {}) + self._output_location = result_conf.get("OutputLocation", None) - encryption_conf = result_conf.get('EncryptionConfiguration', {}) - self._encryption_option = encryption_conf.get('EncryptionOption', None) - self._kms_key = encryption_conf.get('KmsKey', None) + encryption_conf = result_conf.get("EncryptionConfiguration", {}) + self._encryption_option = encryption_conf.get("EncryptionOption", None) + self._kms_key = encryption_conf.get("KmsKey", None) - self._work_group = query_execution.get('WorkGroup', None) + self._work_group = query_execution.get("WorkGroup", None) @property def database(self): @@ -118,31 +119,35 @@ def work_group(self): class AthenaRowFormat(object): - ROW_FORMAT_PARQUET = 'parquet' - ROW_FORMAT_ORC = 'orc' - ROW_FORMAT_CSV = 'csv' - ROW_FORMAT_JSON = 'json' - ROW_FORMAT_AVRO = 'avro' + ROW_FORMAT_PARQUET = "parquet" + ROW_FORMAT_ORC = "orc" + ROW_FORMAT_CSV = "csv" + ROW_FORMAT_JSON = "json" + ROW_FORMAT_AVRO = "avro" @staticmethod def is_valid(value): - return value in [AthenaRowFormat.ROW_FORMAT_PARQUET, - AthenaRowFormat.ROW_FORMAT_ORC, - AthenaRowFormat.ROW_FORMAT_CSV, - AthenaRowFormat.ROW_FORMAT_JSON, - AthenaRowFormat.ROW_FORMAT_AVRO] + return value in [ + AthenaRowFormat.ROW_FORMAT_PARQUET, + AthenaRowFormat.ROW_FORMAT_ORC, + AthenaRowFormat.ROW_FORMAT_CSV, + AthenaRowFormat.ROW_FORMAT_JSON, + AthenaRowFormat.ROW_FORMAT_AVRO, + ] class AthenaCompression(object): - COMPRESSION_SNAPPY = 'snappy' - COMPRESSION_ZLIB = 'zlib' - COMPRESSION_LZO = 'lzo' - COMPRESSION_GZIP = 'gzip' + COMPRESSION_SNAPPY = "snappy" + COMPRESSION_ZLIB = "zlib" + COMPRESSION_LZO = "lzo" + COMPRESSION_GZIP = "gzip" @staticmethod def is_valid(value): - return value in [AthenaCompression.COMPRESSION_SNAPPY, - AthenaCompression.COMPRESSION_ZLIB, - AthenaCompression.COMPRESSION_LZO, - AthenaCompression.COMPRESSION_GZIP] + return value in [ + AthenaCompression.COMPRESSION_SNAPPY, + AthenaCompression.COMPRESSION_ZLIB, + AthenaCompression.COMPRESSION_LZO, + AthenaCompression.COMPRESSION_GZIP, + ] diff --git a/pyathena/pandas_cursor.py b/pyathena/pandas_cursor.py index 317cacdc..4ae881ba 100644 --- a/pyathena/pandas_cursor.py +++ b/pyathena/pandas_cursor.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import logging @@ -16,10 +15,20 
@@ class PandasCursor(BaseCursor, CursorIterator, WithResultSet): - - def __init__(self, connection, s3_staging_dir, schema_name, work_group, - poll_interval, encryption_option, kms_key, converter, formatter, - retry_config, **kwargs): + def __init__( + self, + connection, + s3_staging_dir, + schema_name, + work_group, + poll_interval, + encryption_option, + kms_key, + converter, + formatter, + retry_config, + **kwargs + ): super(PandasCursor, self).__init__( connection=connection, s3_staging_dir=s3_staging_dir, @@ -31,7 +40,8 @@ def __init__(self, connection, s3_staging_dir, schema_name, work_group, converter=converter, formatter=formatter, retry_config=retry_config, - **kwargs) + **kwargs + ) @property def rownumber(self): @@ -42,19 +52,31 @@ def close(self): self._result_set.close() @synchronized - def execute(self, operation, parameters=None, work_group=None, s3_staging_dir=None, - cache_size=0): + def execute( + self, + operation, + parameters=None, + work_group=None, + s3_staging_dir=None, + cache_size=0, + ): self._reset_state() - self._query_id = self._execute(operation, - parameters=parameters, - work_group=work_group, - s3_staging_dir=s3_staging_dir, - cache_size=cache_size) + self._query_id = self._execute( + operation, + parameters=parameters, + work_group=work_group, + s3_staging_dir=s3_staging_dir, + cache_size=cache_size, + ) query_execution = self._poll(self._query_id) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: self._result_set = AthenaPandasResultSet( - self._connection, self._converter, query_execution, self.arraysize, - self._retry_config) + self._connection, + self._converter, + query_execution, + self.arraysize, + self._retry_config, + ) else: raise OperationalError(query_execution.state_change_reason) return self @@ -68,29 +90,29 @@ def executemany(self, operation, seq_of_parameters): @synchronized def cancel(self): if not self._query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') + raise ProgrammingError("QueryExecutionId is none or empty.") self._cancel(self._query_id) @synchronized def fetchone(self): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchone() @synchronized def fetchmany(self, size=None): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchmany(size) @synchronized def fetchall(self): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.fetchall() @synchronized def as_pandas(self): if not self.has_result_set: - raise ProgrammingError('No result set.') + raise ProgrammingError("No result set.") return self._result_set.as_pandas() diff --git a/pyathena/result_set.py b/pyathena/result_set.py index f9e860e1..3c74b526 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import collections import io @@ -12,13 +11,12 @@ from pyathena.common import CursorIterator from pyathena.error import DataError, OperationalError, ProgrammingError from pyathena.model import AthenaQueryExecution -from pyathena.util import retry_api_call, parse_output_location +from pyathena.util import parse_output_location, retry_api_call _logger = logging.getLogger(__name__) class 
WithResultSet(object): - def __init__(self): super(WithResultSet, self).__init__() self._query_id = None @@ -124,13 +122,12 @@ def work_group(self): class AthenaResultSet(CursorIterator): - def __init__(self, connection, converter, query_execution, arraysize, retry_config): super(AthenaResultSet, self).__init__(arraysize=arraysize) self._connection = connection self._converter = converter self._query_execution = query_execution - assert self._query_execution, 'Required argument `query_execution` not found.' + assert self._query_execution, "Required argument `query_execution` not found." self._retry_config = retry_config self._meta_data = None @@ -203,42 +200,44 @@ def description(self): return None return [ ( - m.get('Name', None), - m.get('Type', None), + m.get("Name", None), + m.get("Type", None), None, None, - m.get('Precision', None), - m.get('Scale', None), - m.get('Nullable', None) + m.get("Precision", None), + m.get("Scale", None), + m.get("Nullable", None), ) for m in self._meta_data ] def __fetch(self, next_token=None): if not self.query_id: - raise ProgrammingError('QueryExecutionId is none or empty.') + raise ProgrammingError("QueryExecutionId is none or empty.") if self.state != AthenaQueryExecution.STATE_SUCCEEDED: - raise ProgrammingError('QueryExecutionState is not SUCCEEDED.') + raise ProgrammingError("QueryExecutionState is not SUCCEEDED.") request = { - 'QueryExecutionId': self.query_id, - 'MaxResults': self._arraysize, + "QueryExecutionId": self.query_id, + "MaxResults": self._arraysize, } if next_token: - request.update({'NextToken': next_token}) + request.update({"NextToken": next_token}) try: - response = retry_api_call(self._connection.client.get_query_results, - config=self._retry_config, - logger=_logger, - **request) + response = retry_api_call( + self._connection.client.get_query_results, + config=self._retry_config, + logger=_logger, + **request + ) except Exception as e: - _logger.exception('Failed to fetch result set.') + _logger.exception("Failed to fetch result set.") raise_from(OperationalError(*e.args), e) else: return response def _fetch(self): if not self._next_token: - raise ProgrammingError('NextToken is none or empty.') + raise ProgrammingError("NextToken is none or empty.") response = self.__fetch(self._next_token) self._process_rows(response) @@ -279,40 +278,49 @@ def fetchall(self): return rows def _process_meta_data(self, response): - result_set = response.get('ResultSet', None) + result_set = response.get("ResultSet", None) if not result_set: - raise DataError('KeyError `ResultSet`') - meta_data = result_set.get('ResultSetMetadata', None) + raise DataError("KeyError `ResultSet`") + meta_data = result_set.get("ResultSetMetadata", None) if not meta_data: - raise DataError('KeyError `ResultSetMetadata`') - column_info = meta_data.get('ColumnInfo', None) + raise DataError("KeyError `ResultSetMetadata`") + column_info = meta_data.get("ColumnInfo", None) if column_info is None: - raise DataError('KeyError `ColumnInfo`') + raise DataError("KeyError `ColumnInfo`") self._meta_data = tuple(column_info) def _process_rows(self, response): - result_set = response.get('ResultSet', None) + result_set = response.get("ResultSet", None) if not result_set: - raise DataError('KeyError `ResultSet`') - rows = result_set.get('Rows', None) + raise DataError("KeyError `ResultSet`") + rows = result_set.get("Rows", None) if rows is None: - raise DataError('KeyError `Rows`') + raise DataError("KeyError `Rows`") processed_rows = [] if len(rows) > 0: - offset = 1 if not 
self._next_token and self._is_first_row_column_labels(rows) else 0 + offset = ( + 1 + if not self._next_token and self._is_first_row_column_labels(rows) + else 0 + ) processed_rows = [ - tuple([self._converter.convert(meta.get('Type', None), - row.get('VarCharValue', None)) - for meta, row in zip(self._meta_data, rows[i].get('Data', []))]) + tuple( + [ + self._converter.convert( + meta.get("Type", None), row.get("VarCharValue", None) + ) + for meta, row in zip(self._meta_data, rows[i].get("Data", [])) + ] + ) for i in xrange(offset, len(rows)) ] self._rows.extend(processed_rows) - self._next_token = response.get('NextToken', None) + self._next_token = response.get("NextToken", None) def _is_first_row_column_labels(self, rows): - first_row_data = rows[0].get('Data', []) + first_row_data = rows[0].get("Data", []) for meta, data in zip(self._meta_data, first_row_data): - if meta.get('Name', None) != data.get('VarCharValue', None): + if meta.get("Name", None) != data.get("VarCharValue", None): return False return True @@ -338,11 +346,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): class AthenaPandasResultSet(AthenaResultSet): _parse_dates = [ - 'date', - 'time', - 'time with time zone', - 'timestamp', - 'timestamp with time zone', + "date", + "time", + "time with time zone", + "timestamp", + "timestamp with time zone", ] def __init__(self, connection, converter, query_execution, arraysize, retry_config): @@ -351,40 +359,49 @@ def __init__(self, connection, converter, query_execution, arraysize, retry_conf converter=converter, query_execution=query_execution, arraysize=1, # Fetch one row to retrieve metadata - retry_config=retry_config) + retry_config=retry_config, + ) self._arraysize = arraysize self._client = self._connection.session.client( - 's3', region_name=self._connection.region_name, **self._connection._client_kwargs) - if self.state == AthenaQueryExecution.STATE_SUCCEEDED and \ - self.output_location.endswith(('.csv', '.txt')): + "s3", + region_name=self._connection.region_name, + **self._connection._client_kwargs + ) + if ( + self.state == AthenaQueryExecution.STATE_SUCCEEDED + and self.output_location.endswith((".csv", ".txt")) + ): self._df = self._as_pandas() else: import pandas as pd + self._df = pd.DataFrame() self._iterrows = self._df.iterrows() @property def dtypes(self): return { - d[0]: self._converter.types[d[1]] for d in self.description + d[0]: self._converter.types[d[1]] + for d in self.description if d[1] in self._converter.types } @property def converters(self): return { - d[0]: self._converter.mappings[d[1]] for d in self.description + d[0]: self._converter.mappings[d[1]] + for d in self.description if d[1] in self._converter.mappings } @property def parse_dates(self): - return [ - d[0] for d in self.description if d[1] in self._parse_dates - ] + return [d[0] for d in self.description if d[1] in self._parse_dates] def _trunc_date(self, df): - times = [d[0] for d in self.description if d[1] in ('time', 'time with time zone')] + times = [ + d[0] for d in self.description if d[1] in ("time", "time with time zone") + ] if times: df.loc[:, times] = df.loc[:, times].apply(lambda r: r.dt.time) return df @@ -425,38 +442,43 @@ def fetchall(self): def _as_pandas(self): import pandas as pd + if not self.output_location: - raise ProgrammingError('OutputLocation is none or empty.') + raise ProgrammingError("OutputLocation is none or empty.") bucket, key = parse_output_location(self.output_location) try: - response = retry_api_call(self._client.get_object, - 
config=self._retry_config, - logger=_logger, - Bucket=bucket, - Key=key) + response = retry_api_call( + self._client.get_object, + config=self._retry_config, + logger=_logger, + Bucket=bucket, + Key=key, + ) except Exception as e: - _logger.exception('Failed to download csv.') + _logger.exception("Failed to download csv.") raise_from(OperationalError(*e.args), e) else: - length = response['ContentLength'] + length = response["ContentLength"] if length: - if self.output_location.endswith('.txt'): - sep = '\t' + if self.output_location.endswith(".txt"): + sep = "\t" header = None names = [d[0] for d in self.description] else: # csv format - sep = ',' + sep = "," header = 0 names = None - df = pd.read_csv(io.BytesIO(response['Body'].read()), - sep=sep, - header=header, - names=names, - dtype=self.dtypes, - converters=self.converters, - parse_dates=self.parse_dates, - infer_datetime_format=True, - skip_blank_lines=False) + df = pd.read_csv( + io.BytesIO(response["Body"].read()), + sep=sep, + header=header, + names=names, + dtype=self.dtypes, + converters=self.converters, + parse_dates=self.parse_dates, + infer_datetime_format=True, + skip_blank_lines=False, + ) df = self._trunc_date(df) else: # Allow empty response df = pd.DataFrame() diff --git a/pyathena/sqlalchemy_athena.py b/pyathena/sqlalchemy_athena.py index e7d54a1b..ede2ab85 100644 --- a/pyathena/sqlalchemy_athena.py +++ b/pyathena/sqlalchemy_athena.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import math import numbers @@ -9,13 +8,28 @@ import tenacity from future.utils import raise_from from sqlalchemy import exc, util -from sqlalchemy.engine import reflection, Engine +from sqlalchemy.engine import Engine, reflection from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.exc import NoSuchTableError, OperationalError -from sqlalchemy.sql.compiler import (BIND_PARAMS, BIND_PARAMS_ESC, - IdentifierPreparer, SQLCompiler, DDLCompiler) -from sqlalchemy.sql.sqltypes import (BIGINT, BINARY, BOOLEAN, DATE, DECIMAL, FLOAT, - INTEGER, NULLTYPE, STRINGTYPE, TIMESTAMP) +from sqlalchemy.sql.compiler import ( + BIND_PARAMS, + BIND_PARAMS_ESC, + DDLCompiler, + IdentifierPreparer, + SQLCompiler, +) +from sqlalchemy.sql.sqltypes import ( + BIGINT, + BINARY, + BOOLEAN, + DATE, + DECIMAL, + FLOAT, + INTEGER, + NULLTYPE, + STRINGTYPE, + TIMESTAMP, +) from tenacity import retry_if_exception, stop_after_attempt, wait_exponential import pyathena @@ -25,6 +39,7 @@ class UniversalSet(object): """UniversalSet https://github.com/dropbox/PyHive/blob/master/pyhive/common.py""" + def __contains__(self, item): return True @@ -33,19 +48,19 @@ class AthenaDMLIdentifierPreparer(IdentifierPreparer): """PrestoIdentifierPreparer https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py""" + reserved_words = UniversalSet() class AthenaDDLIdentifierPreparer(IdentifierPreparer): - def __init__( - self, - dialect, - initial_quote='`', - final_quote=None, - escape_quote='`', - quote_case_sensitive_collations=True, - omit_schema=False + self, + dialect, + initial_quote="`", + final_quote=None, + escape_quote="`", + quote_case_sensitive_collations=True, + omit_schema=False, ): super(AthenaDDLIdentifierPreparer, self).__init__( dialect=dialect, @@ -53,7 +68,7 @@ def __init__( final_quote=final_quote, escape_quote=escape_quote, quote_case_sensitive_collations=quote_case_sensitive_collations, - 
omit_schema=omit_schema + omit_schema=omit_schema, ) @@ -63,7 +78,7 @@ class AthenaStatementCompiler(SQLCompiler): https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py""" def visit_char_length_func(self, fn, **kw): - return 'length{0}'.format(self.function_argspec(fn, **kw)) + return "length{0}".format(self.function_argspec(fn, **kw)) def visit_textclause(self, textclause, **kw): def do_bindparam(m): @@ -83,14 +98,11 @@ def do_bindparam(m): # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), - BIND_PARAMS.sub( - do_bindparam, - self.post_process_text(textclause.text)) + BIND_PARAMS.sub(do_bindparam, self.post_process_text(textclause.text)), ) class AthenaDDLCompiler(DDLCompiler): - @property def preparer(self): return self._preparer @@ -100,29 +112,31 @@ def preparer(self, value): pass def __init__( - self, - dialect, - statement, - bind=None, - schema_translate_map=None, - compile_kwargs=util.immutabledict()): + self, + dialect, + statement, + bind=None, + schema_translate_map=None, + compile_kwargs=util.immutabledict(), + ): self._preparer = AthenaDDLIdentifierPreparer(dialect) super(AthenaDDLCompiler, self).__init__( dialect=dialect, statement=statement, bind=bind, schema_translate_map=schema_translate_map, - compile_kwargs=compile_kwargs) + compile_kwargs=compile_kwargs, + ) def visit_create_table(self, create): table = create.element preparer = self.preparer - text = '\nCREATE EXTERNAL ' - text += 'TABLE ' + preparer.format_table(table) + ' ' - text += '(' + text = "\nCREATE EXTERNAL " + text += "TABLE " + preparer.format_table(table) + " " + text += "(" - separator = '\n' + separator = "\n" for create_column in create.columns: column = create_column.element try: @@ -135,7 +149,8 @@ def visit_create_table(self, create): util.raise_from_cause( exc.CompileError( util.u("(in table '{0}', column '{1}'): {2}").format( - table.description, column.name, ce.args[0]) + table.description, column.name, ce.args[0] + ) ) ) @@ -152,47 +167,54 @@ def visit_create_table(self, create): def post_create_table(self, table): raw_connection = table.bind.raw_connection() # TODO Supports orc, avro, json, csv or tsv format - text = 'STORED AS PARQUET\n' + text = "STORED AS PARQUET\n" - location = raw_connection._kwargs['s3_dir'] if 's3_dir' in raw_connection._kwargs \ + location = ( + raw_connection._kwargs["s3_dir"] + if "s3_dir" in raw_connection._kwargs else raw_connection.s3_staging_dir + ) if not location: - raise exc.CompileError('`s3_dir` or `s3_staging_dir` parameter is required' - ' in the connection string.') + raise exc.CompileError( + "`s3_dir` or `s3_staging_dir` parameter is required" + " in the connection string." 
+ ) text += "LOCATION '{0}{1}/{2}/'\n".format(location, table.schema, table.name) - compression = raw_connection._kwargs.get('compression') + compression = raw_connection._kwargs.get("compression") if compression: - text += "TBLPROPERTIES ('parquet.compress'='{0}')\n".format(compression.upper()) + text += "TBLPROPERTIES ('parquet.compress'='{0}')\n".format( + compression.upper() + ) return text _TYPE_MAPPINGS = { - 'boolean': BOOLEAN, - 'real': FLOAT, - 'float': FLOAT, - 'double': FLOAT, - 'tinyint': INTEGER, - 'smallint': INTEGER, - 'integer': INTEGER, - 'bigint': BIGINT, - 'decimal': DECIMAL, - 'char': STRINGTYPE, - 'varchar': STRINGTYPE, - 'array': STRINGTYPE, - 'row': STRINGTYPE, # StructType - 'varbinary': BINARY, - 'map': STRINGTYPE, - 'date': DATE, - 'timestamp': TIMESTAMP, + "boolean": BOOLEAN, + "real": FLOAT, + "float": FLOAT, + "double": FLOAT, + "tinyint": INTEGER, + "smallint": INTEGER, + "integer": INTEGER, + "bigint": BIGINT, + "decimal": DECIMAL, + "char": STRINGTYPE, + "varchar": STRINGTYPE, + "array": STRINGTYPE, + "row": STRINGTYPE, # StructType + "varbinary": BINARY, + "map": STRINGTYPE, + "date": DATE, + "timestamp": TIMESTAMP, } class AthenaDialect(DefaultDialect): - name = 'awsathena' - driver = 'rest' + name = "awsathena" + driver = "rest" preparer = AthenaDMLIdentifierPreparer statement_compiler = AthenaStatementCompiler ddl_compiler = AthenaDDLCompiler @@ -210,8 +232,9 @@ class AthenaDialect(DefaultDialect): postfetch_lastrowid = False _pattern_data_catlog_exception = re.compile( - r'(((Database|Namespace)\ (?P.+))|(Table\ (?P.+)))\ not\ found\.') - _pattern_column_type = re.compile(r'^([a-zA-Z]+)($|\(.+\)$)') + r"(((Database|Namespace)\ (?P.+))|(Table\ (?P
.+)))\ not\ found\." + ) + _pattern_column_type = re.compile(r"^([a-zA-Z]+)($|\(.+\)$)") @classmethod def dbapi(cls): @@ -228,11 +251,12 @@ def create_connect_args(self, url): # {aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/ # {schema_name}?s3_staging_dir={s3_staging_dir}&... opts = { - 'aws_access_key_id': url.username if url.username else None, - 'aws_secret_access_key': url.password if url.password else None, - 'region_name': re.sub(r'^athena\.([a-z0-9-]+)\.amazonaws\.(com|com.cn)$', r'\1', - url.host), - 'schema_name': url.database if url.database else 'default' + "aws_access_key_id": url.username if url.username else None, + "aws_secret_access_key": url.password if url.password else None, + "region_name": re.sub( + r"^athena\.([a-z0-9-]+)\.amazonaws\.(com|com.cn)$", r"\1", url.host + ), + "schema_name": url.database if url.database else "default", } opts.update(url.query) return [[], opts] @@ -254,7 +278,9 @@ def get_table_names(self, connection, schema=None, **kw): SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema}' - """.format(schema=schema) + """.format( + schema=schema + ) return [row.table_name for row in connection.execute(query).fetchall()] def has_table(self, connection, table_name, schema=None): @@ -281,26 +307,39 @@ def get_columns(self, connection, table_name, schema=None, **kw): FROM information_schema.columns WHERE table_schema = '{schema}' AND table_name = '{table}' - """.format(schema=schema, table=table_name) + """.format( + schema=schema, table=table_name + ) retry_config = raw_connection.retry_config retry = tenacity.Retrying( retry=retry_if_exception( - lambda exc: self._retry_if_data_catalog_exception(exc, schema, table_name)), + lambda exc: self._retry_if_data_catalog_exception( + exc, schema, table_name + ) + ), stop=stop_after_attempt(retry_config.attempt), - wait=wait_exponential(multiplier=retry_config.multiplier, - max=retry_config.max_delay, - exp_base=retry_config.exponential_base), - reraise=True) + wait=wait_exponential( + multiplier=retry_config.multiplier, + max=retry_config.max_delay, + exp_base=retry_config.exponential_base, + ), + reraise=True, + ) try: return [ { - 'name': row.column_name, - 'type': _TYPE_MAPPINGS.get(self._get_column_type(row.data_type), NULLTYPE), - 'nullable': True if row.is_nullable == 'YES' else False, - 'default': row.column_default if not self._is_nan(row.column_default) else None, - 'ordinal_position': row.ordinal_position, - 'comment': row.comment, - } for row in retry(connection.execute, query).fetchall() + "name": row.column_name, + "type": _TYPE_MAPPINGS.get( + self._get_column_type(row.data_type), NULLTYPE + ), + "nullable": True if row.is_nullable == "YES" else False, + "default": row.column_default + if not self._is_nan(row.column_default) + else None, + "ordinal_position": row.ordinal_position, + "comment": row.comment, + } + for row in retry(connection.execute, query).fetchall() ] except OperationalError as e: if not self._retry_if_data_catalog_exception(e, schema, table_name): @@ -313,13 +352,14 @@ def _retry_if_data_catalog_exception(self, exc, schema, table_name): return False match = self._pattern_data_catlog_exception.search(str(exc)) - if match and (match.group('schema') == schema or - match.group('table') == table_name): + if match and ( + match.group("schema") == schema or match.group("table") == table_name + ): return False return True def _get_column_type(self, type_): - return self._pattern_column_type.sub(r'\1', type_) + return 
self._pattern_column_type.sub(r"\1", type_) def get_foreign_keys(self, connection, table_name, schema=None, **kw): # Athena has no support for foreign keys. diff --git a/pyathena/util.py b/pyathena/util.py index c1079843..68f63165 100644 --- a/pyathena/util.py +++ b/pyathena/util.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import concurrent import functools import logging -import threading import re +import threading import uuid from collections import OrderedDict from concurrent.futures.thread import ThreadPoolExecutor @@ -16,23 +15,22 @@ from boto3 import Session from future.utils import iteritems from past.builtins import xrange -from tenacity import (after_log, retry_if_exception, - stop_after_attempt, wait_exponential) +from tenacity import after_log, retry_if_exception, stop_after_attempt, wait_exponential from pyathena import DataError, OperationalError, cpu_count from pyathena.model import AthenaCompression _logger = logging.getLogger(__name__) -PATTERN_OUTPUT_LOCATION = re.compile(r'^s3://(?P[a-zA-Z0-9.\-_]+)/(?P.+)$') +PATTERN_OUTPUT_LOCATION = re.compile(r"^s3://(?P[a-zA-Z0-9.\-_]+)/(?P.+)$") def parse_output_location(output_location): match = PATTERN_OUTPUT_LOCATION.search(output_location) if match: - return match.group('bucket'), match.group('key') + return match.group("bucket"), match.group("key") else: - raise DataError('Unknown `output_location` format.') + raise DataError("Unknown `output_location` format.") def get_chunks(df, chunksize=None): @@ -42,7 +40,7 @@ def get_chunks(df, chunksize=None): if chunksize is None: chunksize = rows elif chunksize <= 0: - raise ValueError('Chunk size argument must be greater than zero') + raise ValueError("Chunk size argument must be greater than zero") chunks = int(rows / chunksize) + 1 for i in xrange(chunks): @@ -54,103 +52,136 @@ def get_chunks(df, chunksize=None): def reset_index(df, index_label=None): - df.index.name = index_label if index_label else 'index' + df.index.name = index_label if index_label else "index" try: df.reset_index(inplace=True) except ValueError as e: - raise ValueError('Duplicate name in index/columns: {0}'.format(e)) + raise ValueError("Duplicate name in index/columns: {0}".format(e)) def as_pandas(cursor, coerce_float=False): from pandas import DataFrame + names = [metadata[0] for metadata in cursor.description] - return DataFrame.from_records(cursor.fetchall(), columns=names, - coerce_float=coerce_float) + return DataFrame.from_records( + cursor.fetchall(), columns=names, coerce_float=coerce_float + ) def to_sql_type_mappings(col): import pandas as pd + col_type = pd._lib.infer_dtype(col, skipna=True) - if col_type == 'datetime64' or col_type == 'datetime': - return 'TIMESTAMP' - elif col_type == 'timedelta': - return 'INT' + if col_type == "datetime64" or col_type == "datetime": + return "TIMESTAMP" + elif col_type == "timedelta": + return "INT" elif col_type == "timedelta64": - return 'BIGINT' - elif col_type == 'floating': - if col.dtype == 'float32': - return 'FLOAT' + return "BIGINT" + elif col_type == "floating": + if col.dtype == "float32": + return "FLOAT" else: - return 'DOUBLE' - elif col_type == 'integer': - if col.dtype == 'int32': - return 'INT' + return "DOUBLE" + elif col_type == "integer": + if col.dtype == "int32": + return "INT" else: - return 'BIGINT' - elif col_type == 'boolean': - return 'BOOLEAN' + return "BIGINT" + elif col_type == "boolean": + 
return "BOOLEAN" elif col_type == "date": - return 'DATE' - elif col_type == 'bytes': - return 'BINARY' - elif col_type in ['complex', 'time']: - raise ValueError('{0} datatype not supported'.format(col_type)) - return 'STRING' - - -def to_parquet(df, bucket_name, prefix, retry_config, session_kwargs, client_kwargs, - compression=None, flavor='spark'): + return "DATE" + elif col_type == "bytes": + return "BINARY" + elif col_type in ["complex", "time"]: + raise ValueError("{0} datatype not supported".format(col_type)) + return "STRING" + + +def to_parquet( + df, + bucket_name, + prefix, + retry_config, + session_kwargs, + client_kwargs, + compression=None, + flavor="spark", +): import pyarrow as pa import pyarrow.parquet as pq session = Session(**session_kwargs) - client = session.resource('s3', **client_kwargs) + client = session.resource("s3", **client_kwargs) bucket = client.Bucket(bucket_name) table = pa.Table.from_pandas(df) buf = pa.BufferOutputStream() - pq.write_table(table, buf, - compression=compression, - flavor=flavor) - response = retry_api_call(bucket.put_object, - config=retry_config, - Body=buf.getvalue().to_pybytes(), - Key=prefix + str(uuid.uuid4())) - return 's3://{0}/{1}'.format(response.bucket_name, response.key) - - -def to_sql(df, name, conn, location, schema='default', - index=False, index_label=None, partitions=None, chunksize=None, - if_exists='fail', compression=None, flavor='spark', - type_mappings=to_sql_type_mappings, - executor_class=ThreadPoolExecutor, - max_workers=(cpu_count() or 1) * 5): + pq.write_table(table, buf, compression=compression, flavor=flavor) + response = retry_api_call( + bucket.put_object, + config=retry_config, + Body=buf.getvalue().to_pybytes(), + Key=prefix + str(uuid.uuid4()), + ) + return "s3://{0}/{1}".format(response.bucket_name, response.key) + + +def to_sql( + df, + name, + conn, + location, + schema="default", + index=False, + index_label=None, + partitions=None, + chunksize=None, + if_exists="fail", + compression=None, + flavor="spark", + type_mappings=to_sql_type_mappings, + executor_class=ThreadPoolExecutor, + max_workers=(cpu_count() or 1) * 5, +): # TODO Supports orc, avro, json, csv or tsv format - if if_exists not in ('fail', 'replace', 'append'): - raise ValueError('`{0}` is not valid for if_exists'.format(if_exists)) + if if_exists not in ("fail", "replace", "append"): + raise ValueError("`{0}` is not valid for if_exists".format(if_exists)) if compression is not None and not AthenaCompression.is_valid(compression): - raise ValueError('`{0}` is not valid for compression'.format(compression)) + raise ValueError("`{0}` is not valid for compression".format(compression)) if partitions is None: partitions = [] bucket_name, key_prefix = parse_output_location(location) - bucket = conn.session.resource('s3', region_name=conn.region_name, - **conn._client_kwargs).Bucket(bucket_name) + bucket = conn.session.resource( + "s3", region_name=conn.region_name, **conn._client_kwargs + ).Bucket(bucket_name) cursor = conn.cursor() - table = cursor.execute(""" + table = cursor.execute( + """ SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema}' AND table_name = '{table}' - """.format(schema=schema, table=name)).fetchall() - if if_exists == 'fail': + """.format( + schema=schema, table=name + ) + ).fetchall() + if if_exists == "fail": if table: - raise OperationalError('Table `{0}.{1}` already exists.'.format(schema, name)) - elif if_exists == 'replace': + raise OperationalError( + "Table `{0}.{1}` already 
exists.".format(schema, name) + ) + elif if_exists == "replace": if table: - cursor.execute(""" + cursor.execute( + """ DROP TABLE {schema}.{table} - """.format(schema=schema, table=name)) + """.format( + schema=schema, table=name + ) + ) objects = bucket.objects.filter(Prefix=key_prefix) if list(objects.limit(1)): objects.delete() @@ -160,70 +191,103 @@ def to_sql(df, name, conn, location, schema='default', with executor_class(max_workers=max_workers) as e: futures = [] session_kwargs = deepcopy(conn._session_kwargs) - session_kwargs.update({'profile_name': conn.profile_name}) + session_kwargs.update({"profile_name": conn.profile_name}) client_kwargs = deepcopy(conn._client_kwargs) - client_kwargs.update({'region_name': conn.region_name}) + client_kwargs.update({"region_name": conn.region_name}) if partitions: for keys, group in df.groupby(by=partitions, observed=True): - keys = keys if isinstance(keys, tuple) else (keys, ) + keys = keys if isinstance(keys, tuple) else (keys,) group = group.drop(partitions, axis=1) - partition_prefix = '/'.join(['{0}={1}'.format(key, val) - for key, val in zip(partitions, keys)]) + partition_prefix = "/".join( + ["{0}={1}".format(key, val) for key, val in zip(partitions, keys)] + ) for chunk in get_chunks(group, chunksize): - futures.append(e.submit(to_parquet, chunk, bucket_name, - '{0}{1}/'.format(key_prefix, partition_prefix), - conn._retry_config, session_kwargs, client_kwargs, - compression, flavor)) + futures.append( + e.submit( + to_parquet, + chunk, + bucket_name, + "{0}{1}/".format(key_prefix, partition_prefix), + conn._retry_config, + session_kwargs, + client_kwargs, + compression, + flavor, + ) + ) else: for chunk in get_chunks(df, chunksize): - futures.append(e.submit(to_parquet, chunk, bucket_name, - key_prefix, conn._retry_config, - session_kwargs, client_kwargs, - compression, flavor)) + futures.append( + e.submit( + to_parquet, + chunk, + bucket_name, + key_prefix, + conn._retry_config, + session_kwargs, + client_kwargs, + compression, + flavor, + ) + ) for future in concurrent.futures.as_completed(futures): result = future.result() - _logger.info('to_parquet: {0}'.format(result)) - - ddl = generate_ddl(df=df, - name=name, - location=location, - schema=schema, - partitions=partitions, - compression=compression, - type_mappings=type_mappings) + _logger.info("to_parquet: {0}".format(result)) + + ddl = generate_ddl( + df=df, + name=name, + location=location, + schema=schema, + partitions=partitions, + compression=compression, + type_mappings=type_mappings, + ) _logger.info(ddl) cursor.execute(ddl) if partitions: - repair = 'MSCK REPAIR TABLE {0}.{1}'.format(schema, name) + repair = "MSCK REPAIR TABLE {0}.{1}".format(schema, name) _logger.info(repair) cursor.execute(repair) def get_column_names_and_types(df, type_mappings): - return OrderedDict(( - (str(df.columns[i]), type_mappings(df.iloc[:, i])) - for i in xrange(len(df.columns)) - )) + return OrderedDict( + ( + (str(df.columns[i]), type_mappings(df.iloc[:, i])) + for i in xrange(len(df.columns)) + ) + ) -def generate_ddl(df, name, location, schema='default', partitions=None, compression=None, - type_mappings=to_sql_type_mappings): +def generate_ddl( + df, + name, + location, + schema="default", + partitions=None, + compression=None, + type_mappings=to_sql_type_mappings, +): if partitions is None: partitions = [] column_names_and_types = get_column_names_and_types(df, type_mappings) - ddl = 'CREATE EXTERNAL TABLE IF NOT EXISTS `{0}`.`{1}` (\n'.format(schema, name) - ddl += ',\n'.join([ - 
'`{0}` {1}'.format(col, type_) - for col, type_ in iteritems(column_names_and_types) if col not in partitions - ]) - ddl += '\n)\n' + ddl = "CREATE EXTERNAL TABLE IF NOT EXISTS `{0}`.`{1}` (\n".format(schema, name) + ddl += ",\n".join( + [ + "`{0}` {1}".format(col, type_) + for col, type_ in iteritems(column_names_and_types) + if col not in partitions + ] + ) + ddl += "\n)\n" if partitions: - ddl += 'PARTITIONED BY (\n' - ddl += ',\n'.join([ - '`{0}` {1}'.format(p, column_names_and_types[p]) for p in partitions - ]) - ddl += '\n)\n' - ddl += 'STORED AS PARQUET\n' + ddl += "PARTITIONED BY (\n" + ddl += ",\n".join( + ["`{0}` {1}".format(p, column_names_and_types[p]) for p in partitions] + ) + ddl += "\n)\n" + ddl += "STORED AS PARQUET\n" ddl += "LOCATION '{0}'\n".format(location) if compression: ddl += "TBLPROPERTIES ('parquet.compress'='{0}')\n".format(compression.upper()) @@ -240,13 +304,19 @@ def synchronized(wrapped): def _wrapper(*args, **kwargs): with _lock: return wrapped(*args, **kwargs) + return _wrapper class RetryConfig(object): - - def __init__(self, exceptions=('ThrottlingException', 'TooManyRequestsException'), - attempt=5, multiplier=1, max_delay=100, exponential_base=2): + def __init__( + self, + exceptions=("ThrottlingException", "TooManyRequestsException"), + attempt=5, + multiplier=1, + max_delay=100, + exponential_base=2, + ): self.exceptions = exceptions self.attempt = attempt self.multiplier = multiplier @@ -254,18 +324,21 @@ def __init__(self, exceptions=('ThrottlingException', 'TooManyRequestsException' self.exponential_base = exponential_base -def retry_api_call(func, config, logger=None, - *args, **kwargs): +def retry_api_call(func, config, logger=None, *args, **kwargs): retry = tenacity.Retrying( retry=retry_if_exception( - lambda e: getattr(e, 'response', {}).get( - 'Error', {}).get('Code', None) in config.exceptions - if e else False), + lambda e: getattr(e, "response", {}).get("Error", {}).get("Code", None) + in config.exceptions + if e + else False + ), stop=stop_after_attempt(config.attempt), - wait=wait_exponential(multiplier=config.multiplier, - max=config.max_delay, - exp_base=config.exponential_base), + wait=wait_exponential( + multiplier=config.multiplier, + max=config.max_delay, + exp_base=config.exponential_base, + ), after=after_log(logger, logger.level) if logger else None, - reraise=True + reraise=True, ) return retry(func, *args, **kwargs) diff --git a/setup.py b/setup.py index 669a5b36..d9220be4 100755 --- a/setup.py +++ b/setup.py @@ -9,65 +9,58 @@ import pyathena -with codecs.open('README.rst', 'rb', 'utf-8') as readme: +with codecs.open("README.rst", "rb", "utf-8") as readme: long_description = readme.read() setup( - name='PyAthena', + name="PyAthena", version=pyathena.__version__, - description='Python DB API 2.0 (PEP 249) compliant client for Amazon Athena', + description="Python DB API 2.0 (PEP 249) compliant client for Amazon Athena", long_description=long_description, - url='https://github.com/laughingman7743/PyAthena/', - author='laughingman7743', - author_email='laughingman7743@gmail.com', - license='MIT License', - packages=find_packages('.', exclude=['tests']), - package_data={ - '': ['LICENSE', '*.rst', 'Pipfile*'], - }, + url="https://github.com/laughingman7743/PyAthena/", + author="laughingman7743", + author_email="laughingman7743@gmail.com", + license="MIT License", + packages=find_packages(".", exclude=["tests"]), + package_data={"": ["LICENSE", "*.rst", "Pipfile*"],}, include_package_data=True, - data_files=[ - ('', ['LICENSE'] + 
glob('*.rst') + glob('Pipfile*')), - ], + data_files=[("", ["LICENSE"] + glob("*.rst") + glob("Pipfile*")),], install_requires=[ - 'future', + "future", 'futures;python_version=="2.7"', - 'botocore>=1.5.52', - 'boto3>=1.4.4', - 'tenacity>=4.1.0', + "botocore>=1.5.52", + "boto3>=1.4.4", + "tenacity>=4.1.0", ], extras_require={ - 'Pandas': [ - 'pandas>=0.24.0', - 'pyarrow>=0.15.0' - ], - 'SQLAlchemy': ['SQLAlchemy>=1.0.0, <2.0.0'], + "Pandas": ["pandas>=0.24.0", "pyarrow>=0.15.0"], + "SQLAlchemy": ["SQLAlchemy>=1.0.0, <2.0.0"], }, tests_require=[ - 'SQLAlchemy>=1.0.0, <2.0.0', - 'pytest>=3.5', - 'pytest-cov', - 'pytest-flake8>=1.0.1', + "SQLAlchemy>=1.0.0, <2.0.0", + "pytest>=3.5", + "pytest-cov", + "pytest-flake8>=1.0.1", ], entry_points={ - 'sqlalchemy.dialects': [ - 'awsathena.rest = pyathena.sqlalchemy_athena:AthenaDialect', + "sqlalchemy.dialects": [ + "awsathena.rest = pyathena.sqlalchemy_athena:AthenaDialect", ], }, zip_safe=False, classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Topic :: Database :: Front-Ends', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Database :: Front-Ends", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", ], ) diff --git a/tests/__init__.py b/tests/__init__.py index c528613c..2854151b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,34 +4,35 @@ import string from past.builtins.misc import xrange - from sqlalchemy.dialects import registry registry.register("awsathena.rest", "pyathena.sqlalchemy_athena", "AthenaDialect") BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -S3_PREFIX = 'test_pyathena' -WORK_GROUP = 'test-pyathena' -SCHEMA = 'test_pyathena_' + ''.join([random.choice( - string.ascii_lowercase + string.digits) for _ in xrange(10)]) +S3_PREFIX = "test_pyathena" +WORK_GROUP = "test-pyathena" +SCHEMA = "test_pyathena_" + "".join( + [random.choice(string.ascii_lowercase + string.digits) for _ in xrange(10)] +) class Env(object): - def __init__(self): - self.region_name = os.getenv('AWS_DEFAULT_REGION', None) - assert self.region_name, \ - 'Required environment variable `AWS_DEFAULT_REGION` not found.' - self.s3_staging_dir = os.getenv('AWS_ATHENA_S3_STAGING_DIR', None) - assert self.s3_staging_dir, \ - 'Required environment variable `AWS_ATHENA_S3_STAGING_DIR` not found.' + self.region_name = os.getenv("AWS_DEFAULT_REGION", None) + assert ( + self.region_name + ), "Required environment variable `AWS_DEFAULT_REGION` not found." + self.s3_staging_dir = os.getenv("AWS_ATHENA_S3_STAGING_DIR", None) + assert ( + self.s3_staging_dir + ), "Required environment variable `AWS_ATHENA_S3_STAGING_DIR` not found." 
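For context, and not part of the patch itself: the Env helper above requires two environment variables to be set before the test suite can run. A minimal sketch of providing them, where the region and S3 staging path are hypothetical placeholder values:

import os

# Minimal sketch, not part of the patch: Env above asserts that these two
# variables are present; the values here are hypothetical placeholders.
os.environ["AWS_DEFAULT_REGION"] = "us-west-2"
os.environ["AWS_ATHENA_S3_STAGING_DIR"] = "s3://example-bucket/path/to/"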
ENV = Env() class WithConnect(object): - def connect(self, work_group=None): from pyathena import connect + return connect(schema_name=SCHEMA, work_group=work_group) diff --git a/tests/conftest.py b/tests/conftest.py index 3bdaf245..aa1e13bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import os @@ -9,11 +8,11 @@ import pytest from pyathena import connect -from tests import BASE_PATH, ENV, SCHEMA, S3_PREFIX +from tests import BASE_PATH, ENV, S3_PREFIX, SCHEMA from tests.util import read_query -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def _setup_session(request): request.addfinalizer(_teardown_session) with contextlib.closing(connect()) as conn: @@ -29,42 +28,53 @@ def _teardown_session(): def _create_database(cursor): - for q in read_query(os.path.join(BASE_PATH, 'sql', 'create_database.sql')): + for q in read_query(os.path.join(BASE_PATH, "sql", "create_database.sql")): cursor.execute(q.format(schema=SCHEMA)) def _drop_database(cursor): - for q in read_query(os.path.join(BASE_PATH, 'sql', 'drop_database.sql')): + for q in read_query(os.path.join(BASE_PATH, "sql", "drop_database.sql")): cursor.execute(q.format(schema=SCHEMA)) def _create_table(cursor): - location_one_row = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'one_row') - location_many_rows = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'many_rows') - location_one_row_complex = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'one_row_complex') - location_partition_table = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'partition_table') - location_integer_na_values = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'integer_na_values') - location_boolean_na_values = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'boolean_na_values') - location_execute_many = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'execute_many_{0}'.format( - str(uuid.uuid4()).replace('-', ''))) - location_execute_many_pandas = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'execute_many_pandas_{0}'.format( - str(uuid.uuid4()).replace('-', ''))) - for q in read_query( - os.path.join(BASE_PATH, 'sql', 'create_table.sql')): - cursor.execute(q.format(schema=SCHEMA, - location_one_row=location_one_row, - location_many_rows=location_many_rows, - location_one_row_complex=location_one_row_complex, - location_partition_table=location_partition_table, - location_integer_na_values=location_integer_na_values, - location_boolean_na_values=location_boolean_na_values, - location_execute_many=location_execute_many, - location_execute_many_pandas=location_execute_many_pandas)) + location_one_row = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, "one_row") + location_many_rows = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "many_rows" + ) + location_one_row_complex = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "one_row_complex" + ) + location_partition_table = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "partition_table" + ) + location_integer_na_values = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "integer_na_values" + ) + location_boolean_na_values = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "boolean_na_values" + ) + location_execute_many = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, + 
S3_PREFIX, + "execute_many_{0}".format(str(uuid.uuid4()).replace("-", "")), + ) + location_execute_many_pandas = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, + S3_PREFIX, + "execute_many_pandas_{0}".format(str(uuid.uuid4()).replace("-", "")), + ) + for q in read_query(os.path.join(BASE_PATH, "sql", "create_table.sql")): + cursor.execute( + q.format( + schema=SCHEMA, + location_one_row=location_one_row, + location_many_rows=location_many_rows, + location_one_row_complex=location_one_row_complex, + location_partition_table=location_partition_table, + location_integer_na_values=location_integer_na_values, + location_boolean_na_values=location_boolean_na_values, + location_execute_many=location_execute_many, + location_execute_many_pandas=location_execute_many_pandas, + ) + ) diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py index dc27903d..04b859f0 100644 --- a/tests/test_async_cursor.py +++ b/tests/test_async_cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import time @@ -20,10 +19,9 @@ class TestAsyncCursor(unittest.TestCase, WithConnect): - @with_async_cursor() def test_fetchone(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(result_set.rownumber, 0) self.assertEqual(result_set.fetchone(), (1,)) @@ -43,23 +41,23 @@ def test_fetchone(self, cursor): @with_async_cursor() def test_fetchmany(self, cursor): - query_id, future = cursor.execute('SELECT * FROM many_rows LIMIT 15') + query_id, future = cursor.execute("SELECT * FROM many_rows LIMIT 15") result_set = future.result() self.assertEqual(len(result_set.fetchmany(10)), 10) self.assertEqual(len(result_set.fetchmany(10)), 5) @with_async_cursor() def test_fetchall(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(result_set.fetchall(), [(1,)]) - query_id, future = cursor.execute('SELECT a FROM many_rows ORDER BY a') + query_id, future = cursor.execute("SELECT a FROM many_rows ORDER BY a") result_set = future.result() self.assertEqual(result_set.fetchall(), [(i,) for i in xrange(10000)]) @with_async_cursor() def test_iterator(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(list(result_set), [(1,)]) self.assertRaises(StopIteration, result_set.__next__) @@ -67,7 +65,7 @@ def test_iterator(self, cursor): @with_async_cursor() def test_arraysize(self, cursor): cursor.arraysize = 5 - query_id, future = cursor.execute('SELECT * FROM many_rows LIMIT 20') + query_id, future = cursor.execute("SELECT * FROM many_rows LIMIT 20") result_set = future.result() self.assertEqual(len(result_set.fetchmany()), 5) @@ -84,10 +82,12 @@ def test_invalid_arraysize(self, cursor): @with_async_cursor() def test_description(self, cursor): - query_id, future = cursor.execute('SELECT 1 AS foobar FROM one_row') + query_id, future = cursor.execute("SELECT 1 AS foobar FROM one_row") result_set = future.result() - self.assertEqual(result_set.description, - [('foobar', 'integer', None, None, 10, 0, 'UNKNOWN')]) + self.assertEqual( + result_set.description, + [("foobar", "integer", None, None, 10, 0, 
"UNKNOWN")], + ) future = cursor.description(query_id) description = future.result() @@ -95,7 +95,7 @@ def test_description(self, cursor): @with_async_cursor() def test_query_execution(self, cursor): - query = 'SELECT * FROM one_row' + query = "SELECT * FROM one_row" query_id, future = cursor.execute(query) result_set = future.result() @@ -105,7 +105,9 @@ def test_query_execution(self, cursor): self.assertEqual(query_execution.database, SCHEMA) self.assertIsNotNone(query_execution.query_id) self.assertEqual(query_execution.query, query) - self.assertEqual(query_execution.statement_type, AthenaQueryExecution.STATEMENT_TYPE_DML) + self.assertEqual( + query_execution.statement_type, AthenaQueryExecution.STATEMENT_TYPE_DML + ) self.assertEqual(query_execution.state, AthenaQueryExecution.STATE_SUCCEEDED) self.assertIsNone(query_execution.state_change_reason) self.assertIsNotNone(query_execution.completion_date_time) @@ -117,21 +119,33 @@ def test_query_execution(self, cursor): self.assertIsNotNone(query_execution.output_location) self.assertIsNone(query_execution.encryption_option) self.assertIsNone(query_execution.kms_key) - self.assertEqual(query_execution.work_group, 'primary') + self.assertEqual(query_execution.work_group, "primary") self.assertEqual(result_set.database, query_execution.database) self.assertEqual(result_set.query_id, query_execution.query_id) self.assertEqual(result_set.query, query_execution.query) self.assertEqual(result_set.statement_type, query_execution.statement_type) self.assertEqual(result_set.state, query_execution.state) - self.assertEqual(result_set.state_change_reason, query_execution.state_change_reason) - self.assertEqual(result_set.completion_date_time, query_execution.completion_date_time) - self.assertEqual(result_set.submission_date_time, query_execution.submission_date_time) - self.assertEqual(result_set.data_scanned_in_bytes, query_execution.data_scanned_in_bytes) - self.assertEqual(result_set.execution_time_in_millis, - query_execution.execution_time_in_millis) + self.assertEqual( + result_set.state_change_reason, query_execution.state_change_reason + ) + self.assertEqual( + result_set.completion_date_time, query_execution.completion_date_time + ) + self.assertEqual( + result_set.submission_date_time, query_execution.submission_date_time + ) + self.assertEqual( + result_set.data_scanned_in_bytes, query_execution.data_scanned_in_bytes + ) + self.assertEqual( + result_set.execution_time_in_millis, + query_execution.execution_time_in_millis, + ) self.assertEqual(result_set.output_location, query_execution.output_location) - self.assertEqual(result_set.encryption_option, query_execution.encryption_option) + self.assertEqual( + result_set.encryption_option, query_execution.encryption_option + ) self.assertEqual(result_set.kms_key, query_execution.kms_key) self.assertEqual(result_set.work_group, query_execution.work_group) @@ -140,26 +154,35 @@ def test_poll(self, cursor): query_id, _ = cursor.execute("SELECT * FROM one_row") future = cursor.poll(query_id) query_execution = future.result() - self.assertIn(query_execution.state, [AthenaQueryExecution.STATE_QUEUED, - AthenaQueryExecution.STATE_RUNNING, - AthenaQueryExecution.STATE_SUCCEEDED, - AthenaQueryExecution.STATE_FAILED, - AthenaQueryExecution.STATE_CANCELLED]) + self.assertIn( + query_execution.state, + [ + AthenaQueryExecution.STATE_QUEUED, + AthenaQueryExecution.STATE_RUNNING, + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED, + ], + ) 
@with_async_cursor() def test_bad_query(self, cursor): - query_id, future = cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist') + query_id, future = cursor.execute( + "SELECT does_not_exist FROM this_really_does_not_exist" + ) result_set = future.result() self.assertEqual(result_set.state, AthenaQueryExecution.STATE_FAILED) self.assertIsNotNone(result_set.state_change_reason) @with_async_cursor() def test_cancel(self, cursor): - query_id, future = cursor.execute(""" + query_id, future = cursor.execute( + """ SELECT a.a * rand(), b.a * rand() FROM many_rows a CROSS JOIN many_rows b - """) + """ + ) time.sleep(randint(1, 5)) cursor.cancel(query_id) result_set = future.result() @@ -178,7 +201,8 @@ def test_open_close(self): def test_no_ops(self): conn = self.connect() cursor = conn.cursor(AsyncCursor) - self.assertRaises(NotSupportedError, lambda: cursor.executemany( - 'SELECT * FROM one_row', [])) + self.assertRaises( + NotSupportedError, lambda: cursor.executemany("SELECT * FROM one_row", []) + ) cursor.close() conn.close() diff --git a/tests/test_async_pandas_cursor.py b/tests/test_async_pandas_cursor.py index 66ca5ef3..53ecbbf5 100644 --- a/tests/test_async_pandas_cursor.py +++ b/tests/test_async_pandas_cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import random @@ -22,10 +21,9 @@ class TestAsyncCursor(unittest.TestCase, WithConnect): - @with_async_pandas_cursor() def test_fetchone(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(result_set.rownumber, 0) self.assertEqual(result_set.fetchone(), (1,)) @@ -45,23 +43,23 @@ def test_fetchone(self, cursor): @with_async_pandas_cursor() def test_fetchmany(self, cursor): - query_id, future = cursor.execute('SELECT * FROM many_rows LIMIT 15') + query_id, future = cursor.execute("SELECT * FROM many_rows LIMIT 15") result_set = future.result() self.assertEqual(len(result_set.fetchmany(10)), 10) self.assertEqual(len(result_set.fetchmany(10)), 5) @with_async_pandas_cursor() def test_fetchall(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(result_set.fetchall(), [(1,)]) - query_id, future = cursor.execute('SELECT a FROM many_rows ORDER BY a') + query_id, future = cursor.execute("SELECT a FROM many_rows ORDER BY a") result_set = future.result() self.assertEqual(result_set.fetchall(), [(i,) for i in xrange(10000)]) @with_async_pandas_cursor() def test_iterator(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") result_set = future.result() self.assertEqual(list(result_set), [(1,)]) self.assertRaises(StopIteration, result_set.__next__) @@ -69,7 +67,7 @@ def test_iterator(self, cursor): @with_async_pandas_cursor() def test_arraysize(self, cursor): cursor.arraysize = 5 - query_id, future = cursor.execute('SELECT * FROM many_rows LIMIT 20') + query_id, future = cursor.execute("SELECT * FROM many_rows LIMIT 20") result_set = future.result() self.assertEqual(len(result_set.fetchmany()), 5) @@ -86,10 +84,12 @@ def test_invalid_arraysize(self, cursor): @with_async_pandas_cursor() def test_description(self, cursor): - query_id, future = 
cursor.execute('SELECT 1 AS foobar FROM one_row') + query_id, future = cursor.execute("SELECT 1 AS foobar FROM one_row") result_set = future.result() - self.assertEqual(result_set.description, - [('foobar', 'integer', None, None, 10, 0, 'UNKNOWN')]) + self.assertEqual( + result_set.description, + [("foobar", "integer", None, None, 10, 0, "UNKNOWN")], + ) future = cursor.description(query_id) description = future.result() @@ -97,7 +97,7 @@ def test_description(self, cursor): @with_async_pandas_cursor() def test_query_execution(self, cursor): - query = 'SELECT * FROM one_row' + query = "SELECT * FROM one_row" query_id, future = cursor.execute(query) result_set = future.result() @@ -119,12 +119,22 @@ def test_query_execution(self, cursor): self.assertEqual(result_set.query_id, query_execution.query_id) self.assertEqual(result_set.query, query_execution.query) self.assertEqual(result_set.state, query_execution.state) - self.assertEqual(result_set.state_change_reason, query_execution.state_change_reason) - self.assertEqual(result_set.completion_date_time, query_execution.completion_date_time) - self.assertEqual(result_set.submission_date_time, query_execution.submission_date_time) - self.assertEqual(result_set.data_scanned_in_bytes, query_execution.data_scanned_in_bytes) - self.assertEqual(result_set.execution_time_in_millis, - query_execution.execution_time_in_millis) + self.assertEqual( + result_set.state_change_reason, query_execution.state_change_reason + ) + self.assertEqual( + result_set.completion_date_time, query_execution.completion_date_time + ) + self.assertEqual( + result_set.submission_date_time, query_execution.submission_date_time + ) + self.assertEqual( + result_set.data_scanned_in_bytes, query_execution.data_scanned_in_bytes + ) + self.assertEqual( + result_set.execution_time_in_millis, + query_execution.execution_time_in_millis, + ) self.assertEqual(result_set.output_location, query_execution.output_location) @with_async_pandas_cursor() @@ -132,43 +142,53 @@ def test_poll(self, cursor): query_id, _ = cursor.execute("SELECT * FROM one_row") future = cursor.poll(query_id) query_execution = future.result() - self.assertIn(query_execution.state, [AthenaQueryExecution.STATE_QUEUED, - AthenaQueryExecution.STATE_RUNNING, - AthenaQueryExecution.STATE_SUCCEEDED, - AthenaQueryExecution.STATE_FAILED, - AthenaQueryExecution.STATE_CANCELLED]) + self.assertIn( + query_execution.state, + [ + AthenaQueryExecution.STATE_QUEUED, + AthenaQueryExecution.STATE_RUNNING, + AthenaQueryExecution.STATE_SUCCEEDED, + AthenaQueryExecution.STATE_FAILED, + AthenaQueryExecution.STATE_CANCELLED, + ], + ) @with_async_pandas_cursor() def test_bad_query(self, cursor): - query_id, future = cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist') + query_id, future = cursor.execute( + "SELECT does_not_exist FROM this_really_does_not_exist" + ) result_set = future.result() self.assertEqual(result_set.state, AthenaQueryExecution.STATE_FAILED) self.assertIsNotNone(result_set.state_change_reason) @with_async_pandas_cursor() def test_as_pandas(self, cursor): - query_id, future = cursor.execute('SELECT * FROM one_row') + query_id, future = cursor.execute("SELECT * FROM one_row") df = future.result().as_pandas() self.assertEqual(df.shape[0], 1) self.assertEqual(df.shape[1], 1) - self.assertEqual([(row['number_of_rows'],) for _, row in df.iterrows()], [(1,)]) + self.assertEqual([(row["number_of_rows"],) for _, row in df.iterrows()], [(1,)]) @with_async_pandas_cursor() def test_many_as_pandas(self, cursor): 
- query_id, future = cursor.execute('SELECT * FROM many_rows') + query_id, future = cursor.execute("SELECT * FROM many_rows") df = future.result().as_pandas() self.assertEqual(df.shape[0], 10000) self.assertEqual(df.shape[1], 1) - self.assertEqual([(row['a'],) for _, row in df.iterrows()], - [(i,) for i in xrange(10000)]) + self.assertEqual( + [(row["a"],) for _, row in df.iterrows()], [(i,) for i in xrange(10000)] + ) @with_async_pandas_cursor() def test_cancel(self, cursor): - query_id, future = cursor.execute(""" + query_id, future = cursor.execute( + """ SELECT a.a * rand(), b.a * rand() FROM many_rows a CROSS JOIN many_rows b - """) + """ + ) time.sleep(randint(1, 5)) cursor.cancel(query_id) result_set = future.result() @@ -187,23 +207,29 @@ def test_open_close(self): def test_no_ops(self): conn = self.connect() cursor = conn.cursor(AsyncCursor) - self.assertRaises(NotSupportedError, lambda: cursor.executemany( - 'SELECT * FROM one_row', [])) + self.assertRaises( + NotSupportedError, lambda: cursor.executemany("SELECT * FROM one_row", []) + ) cursor.close() conn.close() @with_async_pandas_cursor() def test_empty_result(self, cursor): - table = 'test_pandas_cursor_empty_result_' + ''.join([random.choice( - string.ascii_lowercase + string.digits) for _ in xrange(10)]) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table) - query_id, future = cursor.execute(""" + table = "test_pandas_cursor_empty_result_" + "".join( + [random.choice(string.ascii_lowercase + string.digits) for _ in xrange(10)] + ) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table) + query_id, future = cursor.execute( + """ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.{table} (number_of_rows INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '{location}' - """.format(schema=SCHEMA, table=table, location=location)) + """.format( + schema=SCHEMA, table=table, location=location + ) + ) df = future.result().as_pandas() self.assertEqual(df.shape[0], 0) self.assertEqual(df.shape[1], 0) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 4fd56697..229417c9 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import re @@ -14,11 +13,21 @@ from past.builtins.misc import xrange -from pyathena import BINARY, BOOLEAN, DATE, DATETIME, JSON, NUMBER, STRING, TIME, connect +from pyathena import ( + BINARY, + BOOLEAN, + DATE, + DATETIME, + JSON, + NUMBER, + STRING, + TIME, + connect, +) from pyathena.cursor import Cursor from pyathena.error import DatabaseError, NotSupportedError, ProgrammingError from pyathena.model import AthenaQueryExecution -from tests import WithConnect, SCHEMA, ENV, S3_PREFIX, WORK_GROUP +from tests import ENV, S3_PREFIX, SCHEMA, WORK_GROUP, WithConnect from tests.util import with_cursor @@ -32,7 +41,7 @@ class TestCursor(unittest.TestCase, WithConnect): @with_cursor() def test_fetchone(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.rownumber, 0) self.assertEqual(cursor.fetchone(), (1,)) self.assertEqual(cursor.rownumber, 1) @@ -52,24 +61,24 @@ def test_fetchone(self, cursor): self.assertIsNotNone(cursor.output_location) self.assertIsNone(cursor.encryption_option) self.assertIsNone(cursor.kms_key) - self.assertEqual(cursor.work_group, 
'primary') + self.assertEqual(cursor.work_group, "primary") @with_cursor() def test_fetchmany(self, cursor): - cursor.execute('SELECT * FROM many_rows LIMIT 15') + cursor.execute("SELECT * FROM many_rows LIMIT 15") self.assertEqual(len(cursor.fetchmany(10)), 10) self.assertEqual(len(cursor.fetchmany(10)), 5) @with_cursor() def test_fetchall(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) - cursor.execute('SELECT a FROM many_rows ORDER BY a') + cursor.execute("SELECT a FROM many_rows ORDER BY a") self.assertEqual(cursor.fetchall(), [(i,) for i in xrange(10000)]) @with_cursor() def test_iterator(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(list(cursor), [(1,)]) self.assertRaises(StopIteration, cursor.__next__) @@ -77,7 +86,7 @@ def test_iterator(self, cursor): def test_cache_size(self, cursor): # To test caching, we need to make sure the query is unique, otherwise # we might accidentally pick up the cache results from another CI run. - query = 'SELECT * FROM one_row -- {0}'.format(str(datetime.utcnow())) + query = "SELECT * FROM one_row -- {0}".format(str(datetime.utcnow())) cursor.execute(query) first_query_id = cursor.query_id @@ -97,13 +106,13 @@ def test_cache_size(self, cursor): @with_cursor(work_group=WORK_GROUP) def test_cache_size_with_work_group(self, cursor): now = datetime.utcnow() - cursor.execute('SELECT %(now)s as date', {'now': now}) + cursor.execute("SELECT %(now)s as date", {"now": now}) first_query_id = cursor.query_id - cursor.execute('SELECT %(now)s as date', {'now': now}) + cursor.execute("SELECT %(now)s as date", {"now": now}) second_query_id = cursor.query_id - cursor.execute('SELECT %(now)s as date', {'now': now}, cache_size=100) + cursor.execute("SELECT %(now)s as date", {"now": now}, cache_size=100) third_query_id = cursor.query_id self.assertNotEqual(first_query_id, second_query_id) @@ -112,7 +121,7 @@ def test_cache_size_with_work_group(self, cursor): @with_cursor() def test_arraysize(self, cursor): cursor.arraysize = 5 - cursor.execute('SELECT * FROM many_rows LIMIT 20') + cursor.execute("SELECT * FROM many_rows LIMIT 20") self.assertEqual(len(cursor.fetchmany()), 5) @with_cursor() @@ -128,9 +137,10 @@ def test_invalid_arraysize(self, cursor): @with_cursor() def test_description(self, cursor): - cursor.execute('SELECT 1 AS foobar FROM one_row') - self.assertEqual(cursor.description, - [('foobar', 'integer', None, None, 10, 0, 'UNKNOWN')]) + cursor.execute("SELECT 1 AS foobar FROM one_row") + self.assertEqual( + cursor.description, [("foobar", "integer", None, None, 10, 0, "UNKNOWN")] + ) @with_cursor() def test_description_initial(self, cursor): @@ -139,7 +149,7 @@ def test_description_initial(self, cursor): @with_cursor() def test_description_failed(self, cursor): try: - cursor.execute('blah_blah') + cursor.execute("blah_blah") except DatabaseError: pass self.assertIsNone(cursor.description) @@ -147,8 +157,9 @@ def test_description_failed(self, cursor): @with_cursor() def test_bad_query(self, cursor): def run(): - cursor.execute('SELECT does_not_exist FROM this_really_does_not_exist') + cursor.execute("SELECT does_not_exist FROM this_really_does_not_exist") cursor.fetchone() + self.assertRaises(DatabaseError, run) @with_cursor() @@ -159,63 +170,87 @@ def test_fetch_no_data(self, cursor): @with_cursor() def test_null_param(self, cursor): - cursor.execute('SELECT %(param)s FROM one_row', {'param': 
None}) + cursor.execute("SELECT %(param)s FROM one_row", {"param": None}) self.assertEqual(cursor.fetchall(), [(None,)]) @with_cursor() def test_no_params(self, cursor): - self.assertRaises(DatabaseError, lambda: cursor.execute( - 'SELECT %(param)s FROM one_row')) - self.assertRaises(KeyError, lambda: cursor.execute( - 'SELECT %(param)s FROM one_row', {'a': 1})) + self.assertRaises( + DatabaseError, lambda: cursor.execute("SELECT %(param)s FROM one_row") + ) + self.assertRaises( + KeyError, lambda: cursor.execute("SELECT %(param)s FROM one_row", {"a": 1}) + ) @with_cursor() def test_contain_special_character_query(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT col_string FROM one_row_complex WHERE col_string LIKE '%str%' - """) - self.assertEqual(cursor.fetchall(), [('a string', )]) - cursor.execute(""" + """ + ) + self.assertEqual(cursor.fetchall(), [("a string",)]) + cursor.execute( + """ SELECT col_string FROM one_row_complex WHERE col_string LIKE '%%str%%' - """) - self.assertEqual(cursor.fetchall(), [('a string', )]) - cursor.execute(""" + """ + ) + self.assertEqual(cursor.fetchall(), [("a string",)]) + cursor.execute( + """ SELECT col_string, '%' FROM one_row_complex WHERE col_string LIKE '%str%' - """) - self.assertEqual(cursor.fetchall(), [('a string', '%')]) - cursor.execute(""" + """ + ) + self.assertEqual(cursor.fetchall(), [("a string", "%")]) + cursor.execute( + """ SELECT col_string, '%%' FROM one_row_complex WHERE col_string LIKE '%%str%%' - """) - self.assertEqual(cursor.fetchall(), [('a string', '%%')]) + """ + ) + self.assertEqual(cursor.fetchall(), [("a string", "%%")]) @with_cursor() def test_contain_special_character_query_with_parameter(self, cursor): - self.assertRaises(TypeError, lambda: cursor.execute( - """ + self.assertRaises( + TypeError, + lambda: cursor.execute( + """ SELECT col_string, %(param)s FROM one_row_complex WHERE col_string LIKE '%str%' - """, {'param': 'a string'})) + """, + {"param": "a string"}, + ), + ) cursor.execute( """ SELECT col_string, %(param)s FROM one_row_complex WHERE col_string LIKE '%%str%%' - """, {'param': 'a string'}) - self.assertEqual(cursor.fetchall(), [('a string', 'a string')]) - self.assertRaises(ValueError, lambda: cursor.execute( - """ + """, + {"param": "a string"}, + ) + self.assertEqual(cursor.fetchall(), [("a string", "a string")]) + self.assertRaises( + ValueError, + lambda: cursor.execute( + """ SELECT col_string, '%' FROM one_row_complex WHERE col_string LIKE %(param)s - """, {'param': '%str%'})) + """, + {"param": "%str%"}, + ), + ) cursor.execute( """ SELECT col_string, '%%' FROM one_row_complex WHERE col_string LIKE %(param)s - """, {'param': '%str%'}) - self.assertEqual(cursor.fetchall(), [('a string', '%')]) + """, + {"param": "%str%"}, + ) + self.assertEqual(cursor.fetchall(), [("a string", "%")]) def test_escape(self): bad_str = """`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\n\r\t """ @@ -223,18 +258,20 @@ def test_escape(self): @with_cursor() def run_escape_case(self, cursor, bad_str): - cursor.execute('SELECT %(a)d, %(b)s FROM one_row', {'a': 1, 'b': bad_str}) + cursor.execute("SELECT %(a)d, %(b)s FROM one_row", {"a": 1, "b": bad_str}) self.assertEqual(cursor.fetchall(), [(1, bad_str,)]) @with_cursor() def test_none_empty_query(self, cursor): self.assertRaises(ProgrammingError, lambda: cursor.execute(None)) - self.assertRaises(ProgrammingError, lambda: cursor.execute('')) + self.assertRaises(ProgrammingError, lambda: cursor.execute("")) @with_cursor() def test_invalid_params(self, cursor): - 
self.assertRaises(TypeError, lambda: cursor.execute( - 'SELECT * FROM one_row', {'foo': {'bar': 1}})) + self.assertRaises( + TypeError, + lambda: cursor.execute("SELECT * FROM one_row", {"foo": {"bar": 1}}), + ) def test_open_close(self): with contextlib.closing(self.connect()): @@ -245,38 +282,42 @@ def test_open_close(self): @with_cursor() def test_unicode(self, cursor): - unicode_str = '王兢' - cursor.execute('SELECT %(param)s FROM one_row', {'param': unicode_str}) + unicode_str = "王兢" + cursor.execute("SELECT %(param)s FROM one_row", {"param": unicode_str}) self.assertEqual(cursor.fetchall(), [(unicode_str,)]) @with_cursor() def test_decimal(self, cursor): - cursor.execute('SELECT %(decimal)s', {'decimal': Decimal('0.00000000001')}) - self.assertEqual(cursor.fetchall(), [(Decimal('0.00000000001'),)]) + cursor.execute("SELECT %(decimal)s", {"decimal": Decimal("0.00000000001")}) + self.assertEqual(cursor.fetchall(), [(Decimal("0.00000000001"),)]) @with_cursor() def test_null(self, cursor): - cursor.execute('SELECT null FROM many_rows') + cursor.execute("SELECT null FROM many_rows") self.assertEqual(cursor.fetchall(), [(None,)] * 10000) - cursor.execute('SELECT IF(a % 11 = 0, null, a) FROM many_rows') - self.assertEqual(cursor.fetchall(), - [(None if a % 11 == 0 else a,) for a in xrange(10000)]) + cursor.execute("SELECT IF(a % 11 = 0, null, a) FROM many_rows") + self.assertEqual( + cursor.fetchall(), [(None if a % 11 == 0 else a,) for a in xrange(10000)] + ) @with_cursor() def test_query_id(self, cursor): self.assertIsNone(cursor.query_id) - cursor.execute('SELECT * from one_row') + cursor.execute("SELECT * from one_row") # query_id is UUID v4 - expected_pattern = \ - r'^[0-9a-f]{8}-[0-9a-f]{4}-[4][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$' + expected_pattern = ( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[4][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" + ) self.assertTrue(re.match(expected_pattern, cursor.query_id)) @with_cursor() def test_output_location(self, cursor): self.assertIsNone(cursor.output_location) - cursor.execute('SELECT * from one_row') - self.assertEqual(cursor.output_location, - '{0}{1}.csv'.format(ENV.s3_staging_dir, cursor.query_id)) + cursor.execute("SELECT * from one_row") + self.assertEqual( + cursor.output_location, + "{0}{1}.csv".format(ENV.s3_staging_dir, cursor.query_id), + ) @with_cursor() def test_query_execution_initial(self, cursor): @@ -294,7 +335,8 @@ def test_query_execution_initial(self, cursor): @with_cursor() def test_complex(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT col_boolean ,col_tinyint @@ -315,72 +357,81 @@ def test_complex(self, cursor): ,col_struct ,col_decimal FROM one_row_complex - """) - self.assertEqual(cursor.description, [ - ('col_boolean', 'boolean', None, None, 0, 0, 'UNKNOWN'), - ('col_tinyint', 'tinyint', None, None, 3, 0, 'UNKNOWN'), - ('col_smallint', 'smallint', None, None, 5, 0, 'UNKNOWN'), - ('col_int', 'integer', None, None, 10, 0, 'UNKNOWN'), - ('col_bigint', 'bigint', None, None, 19, 0, 'UNKNOWN'), - ('col_float', 'float', None, None, 17, 0, 'UNKNOWN'), - ('col_double', 'double', None, None, 17, 0, 'UNKNOWN'), - ('col_string', 'varchar', None, None, 2147483647, 0, 'UNKNOWN'), - ('col_timestamp', 'timestamp', None, None, 3, 0, 'UNKNOWN'), - ('col_time', 'time', None, None, 3, 0, 'UNKNOWN'), - ('col_date', 'date', None, None, 0, 0, 'UNKNOWN'), - ('col_binary', 'varbinary', None, None, 1073741824, 0, 'UNKNOWN'), - ('col_array', 'array', None, None, 0, 0, 'UNKNOWN'), - ('col_array_json', 'json', None, None, 0, 0, 
'UNKNOWN'), - ('col_map', 'map', None, None, 0, 0, 'UNKNOWN'), - ('col_map_json', 'json', None, None, 0, 0, 'UNKNOWN'), - ('col_struct', 'row', None, None, 0, 0, 'UNKNOWN'), - ('col_decimal', 'decimal', None, None, 10, 1, 'UNKNOWN'), - ]) + """ + ) + self.assertEqual( + cursor.description, + [ + ("col_boolean", "boolean", None, None, 0, 0, "UNKNOWN"), + ("col_tinyint", "tinyint", None, None, 3, 0, "UNKNOWN"), + ("col_smallint", "smallint", None, None, 5, 0, "UNKNOWN"), + ("col_int", "integer", None, None, 10, 0, "UNKNOWN"), + ("col_bigint", "bigint", None, None, 19, 0, "UNKNOWN"), + ("col_float", "float", None, None, 17, 0, "UNKNOWN"), + ("col_double", "double", None, None, 17, 0, "UNKNOWN"), + ("col_string", "varchar", None, None, 2147483647, 0, "UNKNOWN"), + ("col_timestamp", "timestamp", None, None, 3, 0, "UNKNOWN"), + ("col_time", "time", None, None, 3, 0, "UNKNOWN"), + ("col_date", "date", None, None, 0, 0, "UNKNOWN"), + ("col_binary", "varbinary", None, None, 1073741824, 0, "UNKNOWN"), + ("col_array", "array", None, None, 0, 0, "UNKNOWN"), + ("col_array_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_map", "map", None, None, 0, 0, "UNKNOWN"), + ("col_map_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_struct", "row", None, None, 0, 0, "UNKNOWN"), + ("col_decimal", "decimal", None, None, 10, 1, "UNKNOWN"), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - datetime(2017, 1, 1, 0, 0, 0), - datetime(2017, 1, 1, 0, 0, 0).time(), - date(2017, 1, 2), - b'123', - '[1, 2]', - [1, 2], - '{1=2, 3=4}', - {'1': 2, '3': 4}, - '{a=1, b=2}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 1, 0, 0, 0).time(), + date(2017, 1, 2), + b"123", + "[1, 2]", + [1, 2], + "{1=2, 3=4}", + {"1": 2, "3": 4}, + "{a=1, b=2}", + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) # compare dbapi type object - self.assertEqual([d[1] for d in cursor.description], [ - BOOLEAN, - NUMBER, - NUMBER, - NUMBER, - NUMBER, - NUMBER, - NUMBER, - STRING, - DATETIME, - TIME, - DATE, - BINARY, - STRING, - JSON, - STRING, - JSON, - STRING, - NUMBER, - ]) + self.assertEqual( + [d[1] for d in cursor.description], + [ + BOOLEAN, + NUMBER, + NUMBER, + NUMBER, + NUMBER, + NUMBER, + NUMBER, + STRING, + DATETIME, + TIME, + DATE, + BINARY, + STRING, + JSON, + STRING, + JSON, + STRING, + NUMBER, + ], + ) @with_cursor() def test_cancel(self, cursor): @@ -391,11 +442,16 @@ def cancel(c): with ThreadPoolExecutor(max_workers=1) as executor: executor.submit(cancel, cursor) - self.assertRaises(DatabaseError, lambda: cursor.execute(""" + self.assertRaises( + DatabaseError, + lambda: cursor.execute( + """ SELECT a.a * rand(), b.a * rand() FROM many_rows a CROSS JOIN many_rows b - """)) + """ + ), + ) @with_cursor() def test_cancel_initial(self, cursor): @@ -405,7 +461,7 @@ def test_multiple_connection(self): def execute_other_thread(): with contextlib.closing(connect(schema_name=SCHEMA)) as conn: with conn.cursor() as cursor: - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") return cursor.fetchall() with ThreadPoolExecutor(max_workers=2) as executor: @@ -418,7 +474,7 @@ def test_no_ops(self): cursor = conn.cursor() self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) - 
cursor.setoutputsize(1, 'blah') + cursor.setoutputsize(1, "blah") conn.commit() self.assertRaises(NotSupportedError, lambda: conn.rollback()) cursor.close() @@ -426,43 +482,45 @@ def test_no_ops(self): @with_cursor() def test_show_partition(self, cursor): - location = '{0}{1}/{2}/'.format( - ENV.s3_staging_dir, S3_PREFIX, 'partition_table') + location = "{0}{1}/{2}/".format( + ENV.s3_staging_dir, S3_PREFIX, "partition_table" + ) for i in xrange(10): - cursor.execute(""" + cursor.execute( + """ ALTER TABLE partition_table ADD PARTITION (b=%(b)d) LOCATION %(location)s - """, {'b': i, 'location': location}) - cursor.execute('SHOW PARTITIONS partition_table') - self.assertEqual(sorted(cursor.fetchall()), - [('b={0}'.format(i),) for i in xrange(10)]) + """, + {"b": i, "location": location}, + ) + cursor.execute("SHOW PARTITIONS partition_table") + self.assertEqual( + sorted(cursor.fetchall()), [("b={0}".format(i),) for i in xrange(10)] + ) @with_cursor(work_group=WORK_GROUP) def test_workgroup(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.work_group, WORK_GROUP) @with_cursor(work_group=WORK_GROUP) def test_no_s3_staging_dir(self, cursor): cursor._s3_staging_dir = None - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertNotEqual(cursor.output_location, None) @with_cursor() def test_executemany(self, cursor): cursor.executemany( - 'INSERT INTO execute_many (a) VALUES (%(a)s)', - [{'a': i} for i in xrange(1, 3)] + "INSERT INTO execute_many (a) VALUES (%(a)s)", + [{"a": i} for i in xrange(1, 3)], ) - cursor.execute('SELECT * FROM execute_many') + cursor.execute("SELECT * FROM execute_many") self.assertEqual(sorted(cursor.fetchall()), [(i,) for i in xrange(1, 3)]) @with_cursor() def test_executemany_fetch(self, cursor): - cursor.executemany( - 'SELECT %(x)d FROM one_row', - [{'x': i} for i in range(1, 2)] - ) + cursor.executemany("SELECT %(x)d FROM one_row", [{"x": i} for i in range(1, 2)]) # Operations that have result sets are not allowed with executemany. 
self.assertRaises(ProgrammingError, cursor.fetchall) self.assertRaises(ProgrammingError, cursor.fetchmany) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 68ce3a4c..f8b2c17a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import unittest from datetime import date, datetime @@ -25,10 +24,13 @@ def test_add_partition(self): ADD PARTITION (dt=DATE '2017-01-01', hour=1) """.strip() - actual = self.format(""" + actual = self.format( + """ ALTER TABLE test_table ADD PARTITION (dt=%(dt)s, hour=%(hour)d) - """, {'dt': date(2017, 1, 1), 'hour': 1}) + """, + {"dt": date(2017, 1, 1), "hour": 1}, + ) self.assertEqual(actual, expected) def test_drop_partition(self): @@ -37,10 +39,13 @@ def test_drop_partition(self): DROP PARTITION (dt=DATE '2017-01-01', hour=1) """.strip() - actual = self.format(""" + actual = self.format( + """ ALTER TABLE test_table DROP PARTITION (dt=%(dt)s, hour=%(hour)d) - """, {'dt': date(2017, 1, 1), 'hour': 1}) + """, + {"dt": date(2017, 1, 1), "hour": 1}, + ) self.assertEqual(actual, expected) def test_format_none(self): @@ -50,11 +55,14 @@ def test_format_none(self): WHERE col is null """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col is %(param)s - """, {'param': None}) + """, + {"param": None}, + ) self.assertEqual(actual, expected) def test_format_datetime(self): @@ -65,12 +73,18 @@ def test_format_datetime(self): AND col_timestamp <= TIMESTAMP '2017-01-02 06:00:00.000' """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_timestamp >= %(start)s AND col_timestamp <= %(end)s - """, {'start': datetime(2017, 1, 1, 12, 0, 0), 'end': datetime(2017, 1, 2, 6, 0, 0)}) + """, + { + "start": datetime(2017, 1, 1, 12, 0, 0), + "end": datetime(2017, 1, 2, 6, 0, 0), + }, + ) self.assertEqual(actual, expected) def test_format_date(self): @@ -80,11 +94,14 @@ def test_format_date(self): WHERE col_date between DATE '2017-01-01' and DATE '2017-01-02' """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_date between %(start)s and %(end)s - """, {'start': date(2017, 1, 1), 'end': date(2017, 1, 2)}) + """, + {"start": date(2017, 1, 1), "end": date(2017, 1, 2)}, + ) self.assertEqual(actual, expected) def test_format_int(self): @@ -94,11 +111,14 @@ def test_format_int(self): WHERE col_int = 1 """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_int = %(param)s - """, {'param': 1}) + """, + {"param": 1}, + ) self.assertEqual(actual, expected) def test_format_float(self): @@ -108,11 +128,14 @@ def test_format_float(self): WHERE col_float >= 0.1 """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_float >= %(param).1f - """, {'param': 0.1}) + """, + {"param": 0.1}, + ) self.assertEqual(actual, expected) def test_format_decimal(self): @@ -122,11 +145,14 @@ def test_format_decimal(self): WHERE col_decimal <= DECIMAL '0.0000000001' """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_decimal <= %(param)s - """, {'param': Decimal('0.0000000001')}) + """, + {"param": Decimal("0.0000000001")}, + ) self.assertEqual(actual, expected) def test_format_bool(self): @@ -136,11 +162,14 @@ def 
test_format_bool(self): WHERE col_boolean = True """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_boolean = %(param)s - """, {'param': True}) + """, + {"param": True}, + ) self.assertEqual(actual, expected) def test_format_str(self): @@ -150,11 +179,14 @@ def test_format_str(self): WHERE col_string = 'amazon athena' """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_string = %(param)s - """, {'param': 'amazon athena'}) + """, + {"param": "amazon athena"}, + ) self.assertEqual(actual, expected) def test_format_unicode(self): @@ -164,11 +196,14 @@ def test_format_unicode(self): WHERE col_string = '密林 女神' """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_string = %(param)s - """, {'param': '密林 女神'}) + """, + {"param": "密林 女神"}, + ) self.assertEqual(actual, expected) def test_format_none_list(self): @@ -178,11 +213,14 @@ def test_format_none_list(self): WHERE col IN (null, null) """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col IN %(param)s - """, {'param': [None, None]}) + """, + {"param": [None, None]}, + ) self.assertEqual(actual, expected) def test_format_datetime_list(self): @@ -193,12 +231,15 @@ def test_format_datetime_list(self): (TIMESTAMP '2017-01-01 12:00:00.000', TIMESTAMP '2017-01-02 06:00:00.000') """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_timestamp IN %(param)s - """, {'param': [datetime(2017, 1, 1, 12, 0, 0), datetime(2017, 1, 2, 6, 0, 0)]}) + """, + {"param": [datetime(2017, 1, 1, 12, 0, 0), datetime(2017, 1, 2, 6, 0, 0)]}, + ) self.assertEqual(actual, expected) def test_format_date_list(self): @@ -208,11 +249,14 @@ def test_format_date_list(self): WHERE col_date IN (DATE '2017-01-01', DATE '2017-01-02') """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_date IN %(param)s - """, {'param': [date(2017, 1, 1), date(2017, 1, 2)]}) + """, + {"param": [date(2017, 1, 1), date(2017, 1, 2)]}, + ) self.assertEqual(actual, expected) def test_format_int_list(self): @@ -222,11 +266,14 @@ def test_format_int_list(self): WHERE col_int IN (1, 2) """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_int IN %(param)s - """, {'param': [1, 2]}) + """, + {"param": [1, 2]}, + ) self.assertEqual(actual, expected) def test_format_float_list(self): @@ -237,11 +284,14 @@ def test_format_float_list(self): WHERE col_float IN (0.100000, 0.200000) """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_float IN %(param)s - """, {'param': [0.1, 0.2]}) + """, + {"param": [0.1, 0.2]}, + ) self.assertEqual(actual, expected) def test_format_decimal_list(self): @@ -251,11 +301,14 @@ def test_format_decimal_list(self): WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999') """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_decimal IN %(param)s - """, {'param': [Decimal('0.0000000001'), Decimal('99.9999999999')]}) + """, + {"param": [Decimal("0.0000000001"), Decimal("99.9999999999")]}, + ) self.assertEqual(actual, expected) def test_format_bool_list(self): @@ -265,11 +318,14 @@ def test_format_bool_list(self): WHERE col_boolean IN (True, False) """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM 
test_table WHERE col_boolean IN %(param)s - """, {'param': [True, False]}) + """, + {"param": [True, False]}, + ) self.assertEqual(actual, expected) def test_format_str_list(self): @@ -279,11 +335,14 @@ def test_format_str_list(self): WHERE col_string IN ('amazon', 'athena') """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_string IN %(param)s - """, {'param': ['amazon', 'athena']}) + """, + {"param": ["amazon", "athena"]}, + ) self.assertEqual(actual, expected) def test_format_unicode_list(self): @@ -293,28 +352,49 @@ def test_format_unicode_list(self): WHERE col_string IN ('密林', '女神') """.strip() - actual = self.format(""" + actual = self.format( + """ SELECT * FROM test_table WHERE col_string IN %(param)s - """, {'param': ['密林', '女神']}) + """, + {"param": ["密林", "女神"]}, + ) self.assertEqual(actual, expected) def test_format_bad_parameter(self): - self.assertRaises(ProgrammingError, lambda: self.format(""" + self.assertRaises( + ProgrammingError, + lambda: self.format( + """ SELECT * FROM test_table where col_int = $(param)d - """.strip(), 1)) + """.strip(), + 1, + ), + ) - self.assertRaises(ProgrammingError, lambda: self.format(""" + self.assertRaises( + ProgrammingError, + lambda: self.format( + """ SELECT * FROM test_table where col_string = $(param)s - """.strip(), 'a string')) + """.strip(), + "a string", + ), + ) - self.assertRaises(ProgrammingError, lambda: self.format(""" + self.assertRaises( + ProgrammingError, + lambda: self.format( + """ SELECT * FROM test_table where col_string in $(param)s - """.strip(), ['a string'])) + """.strip(), + ["a string"], + ), + ) diff --git a/tests/test_model.py b/tests/test_model.py index cfb326aa..e29dbdb4 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,25 +1,22 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import unittest from datetime import datetime -from pyathena.model import AthenaQueryExecution, AthenaRowFormat, AthenaCompression +from pyathena.model import AthenaCompression, AthenaQueryExecution, AthenaRowFormat ATHENA_QUERY_EXECUTION_RESPONSE = { "QueryExecution": { "Query": "SELECT * FROM test_table", - "QueryExecutionContext": { - "Database": "test_database" - }, + "QueryExecutionContext": {"Database": "test_database"}, "QueryExecutionId": "12345678-90ab-cdef-1234-567890abcdef", "ResultConfiguration": { "EncryptionConfiguration": { "EncryptionOption": "test_encryption_option", - "KmsKey": "test_kms_key" + "KmsKey": "test_kms_key", }, - "OutputLocation": "s3://bucket/path/to/" + "OutputLocation": "s3://bucket/path/to/", }, "StatementType": "DML", "Statistics": { @@ -32,44 +29,41 @@ "StateChangeReason": "test_reason", "SubmissionDateTime": datetime(2019, 1, 1, 0, 0, 0), }, - "WorkGroup": "test_work_group" + "WorkGroup": "test_work_group", } } class TestAthenaQueryExecution(unittest.TestCase): - def test_init(self): actual = AthenaQueryExecution(ATHENA_QUERY_EXECUTION_RESPONSE) - self.assertEqual(actual.database, 'test_database') - self.assertEqual(actual.query_id, '12345678-90ab-cdef-1234-567890abcdef') - self.assertEqual(actual.query, 'SELECT * FROM test_table') - self.assertEqual(actual.statement_type, 'DML') - self.assertEqual(actual.state, 'SUCCEEDED') - self.assertEqual(actual.state_change_reason, 'test_reason') + self.assertEqual(actual.database, "test_database") + self.assertEqual(actual.query_id, "12345678-90ab-cdef-1234-567890abcdef") + 
self.assertEqual(actual.query, "SELECT * FROM test_table") + self.assertEqual(actual.statement_type, "DML") + self.assertEqual(actual.state, "SUCCEEDED") + self.assertEqual(actual.state_change_reason, "test_reason") self.assertEqual(actual.completion_date_time, datetime(2019, 1, 1, 0, 0, 0)) self.assertEqual(actual.submission_date_time, datetime(2019, 1, 1, 0, 0, 0)) self.assertEqual(actual.data_scanned_in_bytes, 1234567890) self.assertEqual(actual.execution_time_in_millis, 1234567890) - self.assertEqual(actual.output_location, 's3://bucket/path/to/') - self.assertEqual(actual.encryption_option, 'test_encryption_option') - self.assertEqual(actual.kms_key, 'test_kms_key') - self.assertEqual(actual.work_group, 'test_work_group') + self.assertEqual(actual.output_location, "s3://bucket/path/to/") + self.assertEqual(actual.encryption_option, "test_encryption_option") + self.assertEqual(actual.kms_key, "test_kms_key") + self.assertEqual(actual.work_group, "test_work_group") class TestAthenaRowFormat(unittest.TestCase): - def test_is_valid(self): - self.assertTrue(AthenaRowFormat.is_valid('parquet')) + self.assertTrue(AthenaRowFormat.is_valid("parquet")) self.assertFalse(AthenaRowFormat.is_valid(None)) - self.assertFalse(AthenaRowFormat.is_valid('')) - self.assertFalse(AthenaRowFormat.is_valid('foobar')) + self.assertFalse(AthenaRowFormat.is_valid("")) + self.assertFalse(AthenaRowFormat.is_valid("foobar")) class TestAthenaCompression(unittest.TestCase): - def test_is_valid(self): - self.assertTrue(AthenaCompression.is_valid('snappy')) + self.assertTrue(AthenaCompression.is_valid("snappy")) self.assertFalse(AthenaCompression.is_valid(None)) - self.assertFalse(AthenaCompression.is_valid('')) - self.assertFalse(AthenaCompression.is_valid('foobar')) + self.assertFalse(AthenaCompression.is_valid("")) + self.assertFalse(AthenaCompression.is_valid("foobar")) diff --git a/tests/test_pandas_cursor.py b/tests/test_pandas_cursor.py index 02535185..3641ab03 100644 --- a/tests/test_pandas_cursor.py +++ b/tests/test_pandas_cursor.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import contextlib import random @@ -21,15 +20,14 @@ from pyathena.model import AthenaQueryExecution from pyathena.pandas_cursor import PandasCursor from pyathena.result_set import AthenaPandasResultSet -from tests import WithConnect, SCHEMA, ENV, S3_PREFIX +from tests import ENV, S3_PREFIX, SCHEMA, WithConnect from tests.util import with_pandas_cursor class TestPandasCursor(unittest.TestCase, WithConnect): - @with_pandas_cursor() def test_fetchone(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.rownumber, 0) self.assertEqual(cursor.fetchone(), (1,)) self.assertEqual(cursor.rownumber, 1) @@ -37,27 +35,27 @@ def test_fetchone(self, cursor): @with_pandas_cursor() def test_fetchmany(self, cursor): - cursor.execute('SELECT * FROM many_rows LIMIT 15') + cursor.execute("SELECT * FROM many_rows LIMIT 15") self.assertEqual(len(cursor.fetchmany(10)), 10) self.assertEqual(len(cursor.fetchmany(10)), 5) @with_pandas_cursor() def test_fetchall(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(cursor.fetchall(), [(1,)]) - cursor.execute('SELECT a FROM many_rows ORDER BY a') + cursor.execute("SELECT a FROM many_rows ORDER BY a") 
self.assertEqual(cursor.fetchall(), [(i,) for i in xrange(10000)]) @with_pandas_cursor() def test_iterator(self, cursor): - cursor.execute('SELECT * FROM one_row') + cursor.execute("SELECT * FROM one_row") self.assertEqual(list(cursor), [(1,)]) self.assertRaises(StopIteration, cursor.__next__) @with_pandas_cursor() def test_arraysize(self, cursor): cursor.arraysize = 5 - cursor.execute('SELECT * FROM many_rows LIMIT 20') + cursor.execute("SELECT * FROM many_rows LIMIT 20") self.assertEqual(len(cursor.fetchmany()), 5) @with_pandas_cursor() @@ -73,7 +71,8 @@ def test_invalid_arraysize(self, cursor): @with_pandas_cursor() def test_complex(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT col_boolean ,col_tinyint @@ -94,48 +93,54 @@ def test_complex(self, cursor): ,col_struct ,col_decimal FROM one_row_complex - """) - self.assertEqual(cursor.description, [ - ('col_boolean', 'boolean', None, None, 0, 0, 'UNKNOWN'), - ('col_tinyint', 'tinyint', None, None, 3, 0, 'UNKNOWN'), - ('col_smallint', 'smallint', None, None, 5, 0, 'UNKNOWN'), - ('col_int', 'integer', None, None, 10, 0, 'UNKNOWN'), - ('col_bigint', 'bigint', None, None, 19, 0, 'UNKNOWN'), - ('col_float', 'float', None, None, 17, 0, 'UNKNOWN'), - ('col_double', 'double', None, None, 17, 0, 'UNKNOWN'), - ('col_string', 'varchar', None, None, 2147483647, 0, 'UNKNOWN'), - ('col_timestamp', 'timestamp', None, None, 3, 0, 'UNKNOWN'), - ('col_time', 'time', None, None, 3, 0, 'UNKNOWN'), - ('col_date', 'date', None, None, 0, 0, 'UNKNOWN'), - ('col_binary', 'varbinary', None, None, 1073741824, 0, 'UNKNOWN'), - ('col_array', 'array', None, None, 0, 0, 'UNKNOWN'), - ('col_array_json', 'json', None, None, 0, 0, 'UNKNOWN'), - ('col_map', 'map', None, None, 0, 0, 'UNKNOWN'), - ('col_map_json', 'json', None, None, 0, 0, 'UNKNOWN'), - ('col_struct', 'row', None, None, 0, 0, 'UNKNOWN'), - ('col_decimal', 'decimal', None, None, 10, 1, 'UNKNOWN'), - ]) + """ + ) + self.assertEqual( + cursor.description, + [ + ("col_boolean", "boolean", None, None, 0, 0, "UNKNOWN"), + ("col_tinyint", "tinyint", None, None, 3, 0, "UNKNOWN"), + ("col_smallint", "smallint", None, None, 5, 0, "UNKNOWN"), + ("col_int", "integer", None, None, 10, 0, "UNKNOWN"), + ("col_bigint", "bigint", None, None, 19, 0, "UNKNOWN"), + ("col_float", "float", None, None, 17, 0, "UNKNOWN"), + ("col_double", "double", None, None, 17, 0, "UNKNOWN"), + ("col_string", "varchar", None, None, 2147483647, 0, "UNKNOWN"), + ("col_timestamp", "timestamp", None, None, 3, 0, "UNKNOWN"), + ("col_time", "time", None, None, 3, 0, "UNKNOWN"), + ("col_date", "date", None, None, 0, 0, "UNKNOWN"), + ("col_binary", "varbinary", None, None, 1073741824, 0, "UNKNOWN"), + ("col_array", "array", None, None, 0, 0, "UNKNOWN"), + ("col_array_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_map", "map", None, None, 0, 0, "UNKNOWN"), + ("col_map_json", "json", None, None, 0, 0, "UNKNOWN"), + ("col_struct", "row", None, None, 0, 0, "UNKNOWN"), + ("col_decimal", "decimal", None, None, 10, 1, "UNKNOWN"), + ], + ) rows = cursor.fetchall() - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - pd.Timestamp(2017, 1, 1, 0, 0, 0), - datetime(2017, 1, 1, 0, 0, 0).time(), - pd.Timestamp(2017, 1, 2), - b'123', - '[1, 2]', - [1, 2], - '{1=2, 3=4}', - {'1': 2, '3': 4}, - '{a=1, b=2}', - Decimal('0.1'), - )] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + pd.Timestamp(2017, 1, 1, 0, 0, 0), + datetime(2017, 
1, 1, 0, 0, 0).time(), + pd.Timestamp(2017, 1, 2), + b"123", + "[1, 2]", + [1, 2], + "{1=2, 3=4}", + {"1": 2, "3": 4}, + "{a=1, b=2}", + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) @with_pandas_cursor() @@ -147,10 +152,10 @@ def test_fetch_no_data(self, cursor): @with_pandas_cursor() def test_as_pandas(self, cursor): - df = cursor.execute('SELECT * FROM one_row').as_pandas() + df = cursor.execute("SELECT * FROM one_row").as_pandas() self.assertEqual(df.shape[0], 1) self.assertEqual(df.shape[1], 1) - self.assertEqual([(row['number_of_rows'],) for _, row in df.iterrows()], [(1,)]) + self.assertEqual([(row["number_of_rows"],) for _, row in df.iterrows()], [(1,)]) self.assertIsNotNone(cursor.query_id) self.assertIsNotNone(cursor.query) self.assertEqual(cursor.state, AthenaQueryExecution.STATE_SUCCEEDED) @@ -165,15 +170,17 @@ def test_as_pandas(self, cursor): @with_pandas_cursor() def test_many_as_pandas(self, cursor): - df = cursor.execute('SELECT * FROM many_rows').as_pandas() + df = cursor.execute("SELECT * FROM many_rows").as_pandas() self.assertEqual(df.shape[0], 10000) self.assertEqual(df.shape[1], 1) - self.assertEqual([(row['a'],) for _, row in df.iterrows()], - [(i,) for i in xrange(10000)]) + self.assertEqual( + [(row["a"],) for _, row in df.iterrows()], [(i,) for i in xrange(10000)] + ) @with_pandas_cursor() def test_complex_as_pandas(self, cursor): - df = cursor.execute(""" + df = cursor.execute( + """ SELECT col_boolean ,col_tinyint @@ -194,89 +201,107 @@ def test_complex_as_pandas(self, cursor): ,col_struct ,col_decimal FROM one_row_complex - """).as_pandas() + """ + ).as_pandas() self.assertEqual(df.shape[0], 1) self.assertEqual(df.shape[1], 18) - dtypes = tuple([ - df['col_boolean'].dtype.type, - df['col_tinyint'].dtype.type, - df['col_smallint'].dtype.type, - df['col_int'].dtype.type, - df['col_bigint'].dtype.type, - df['col_float'].dtype.type, - df['col_double'].dtype.type, - df['col_string'].dtype.type, - df['col_timestamp'].dtype.type, - df['col_time'].dtype.type, - df['col_date'].dtype.type, - df['col_binary'].dtype.type, - df['col_array'].dtype.type, - df['col_array_json'].dtype.type, - df['col_map'].dtype.type, - df['col_map_json'].dtype.type, - df['col_struct'].dtype.type, - df['col_decimal'].dtype.type, - ]) - self.assertEqual(dtypes, tuple([ - np.bool_, - np.int64, - np.int64, - np.int64, - np.int64, - np.float64, - np.float64, - np.object_, - np.datetime64, - np.object_, - np.datetime64, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - ])) - rows = [tuple([ - row['col_boolean'], - row['col_tinyint'], - row['col_smallint'], - row['col_int'], - row['col_bigint'], - row['col_float'], - row['col_double'], - row['col_string'], - row['col_timestamp'], - row['col_time'], - row['col_date'], - row['col_binary'], - row['col_array'], - row['col_array_json'], - row['col_map'], - row['col_map_json'], - row['col_struct'], - row['col_decimal'], - ]) for _, row in df.iterrows()] - self.assertEqual(rows, [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - pd.Timestamp(2017, 1, 1, 0, 0, 0), - datetime(2017, 1, 1, 0, 0, 0).time(), - pd.Timestamp(2017, 1, 2), - b'123', - '[1, 2]', - [1, 2], - '{1=2, 3=4}', - {'1': 2, '3': 4}, - '{a=1, b=2}', - Decimal('0.1'), - )]) + dtypes = tuple( + [ + df["col_boolean"].dtype.type, + df["col_tinyint"].dtype.type, + df["col_smallint"].dtype.type, + df["col_int"].dtype.type, + df["col_bigint"].dtype.type, + df["col_float"].dtype.type, + 
df["col_double"].dtype.type, + df["col_string"].dtype.type, + df["col_timestamp"].dtype.type, + df["col_time"].dtype.type, + df["col_date"].dtype.type, + df["col_binary"].dtype.type, + df["col_array"].dtype.type, + df["col_array_json"].dtype.type, + df["col_map"].dtype.type, + df["col_map_json"].dtype.type, + df["col_struct"].dtype.type, + df["col_decimal"].dtype.type, + ] + ) + self.assertEqual( + dtypes, + tuple( + [ + np.bool_, + np.int64, + np.int64, + np.int64, + np.int64, + np.float64, + np.float64, + np.object_, + np.datetime64, + np.object_, + np.datetime64, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + ] + ), + ) + rows = [ + tuple( + [ + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_timestamp"], + row["col_time"], + row["col_date"], + row["col_binary"], + row["col_array"], + row["col_array_json"], + row["col_map"], + row["col_map_json"], + row["col_struct"], + row["col_decimal"], + ] + ) + for _, row in df.iterrows() + ] + self.assertEqual( + rows, + [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + pd.Timestamp(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 1, 0, 0, 0).time(), + pd.Timestamp(2017, 1, 2), + b"123", + "[1, 2]", + [1, 2], + "{1=2, 3=4}", + {"1": 2, "3": 4}, + "{a=1, b=2}", + Decimal("0.1"), + ) + ], + ) @with_pandas_cursor() def test_cancel(self, cursor): @@ -287,11 +312,16 @@ def cancel(c): with ThreadPoolExecutor(max_workers=1) as executor: executor.submit(cancel, cursor) - self.assertRaises(DatabaseError, lambda: cursor.execute(""" + self.assertRaises( + DatabaseError, + lambda: cursor.execute( + """ SELECT a.a * rand(), b.a * rand() FROM many_rows a CROSS JOIN many_rows b - """)) + """ + ), + ) @with_pandas_cursor() def test_cancel_initial(self, cursor): @@ -310,77 +340,65 @@ def test_no_ops(self): @with_pandas_cursor() def test_show_columns(self, cursor): - cursor.execute('SHOW COLUMNS IN one_row') - self.assertEqual(cursor.fetchall(), [('number_of_rows ',)]) + cursor.execute("SHOW COLUMNS IN one_row") + self.assertEqual(cursor.fetchall(), [("number_of_rows ",)]) @with_pandas_cursor() def test_empty_result(self, cursor): - table = 'test_pandas_cursor_empty_result_' + ''.join([random.choice( - string.ascii_lowercase + string.digits) for _ in xrange(10)]) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table) - df = cursor.execute(""" + table = "test_pandas_cursor_empty_result_" + "".join( + [random.choice(string.ascii_lowercase + string.digits) for _ in xrange(10)] + ) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table) + df = cursor.execute( + """ CREATE EXTERNAL TABLE IF NOT EXISTS {schema}.{table} (number_of_rows INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\n' STORED AS TEXTFILE LOCATION '{location}' - """.format(schema=SCHEMA, table=table, location=location)).as_pandas() + """.format( + schema=SCHEMA, table=table, location=location + ) + ).as_pandas() self.assertEqual(df.shape[0], 0) self.assertEqual(df.shape[1], 0) @with_pandas_cursor() def test_integer_na_values(self, cursor): - df = cursor.execute(""" + df = cursor.execute( + """ SELECT * FROM integer_na_values - """).as_pandas() - rows = [tuple([ - row['a'], - row['b'], - ]) for _, row in df.iterrows()] - version = float(re.search(r'^([\d]+\.[\d]+)\..+', pd.__version__).group(1)) + """ + ).as_pandas() + rows = 
[tuple([row["a"], row["b"],]) for _, row in df.iterrows()] + version = float(re.search(r"^([\d]+\.[\d]+)\..+", pd.__version__).group(1)) if version >= 1.0: - self.assertEqual(rows, [ - (1, 2), - (1, pd.NA), - (pd.NA, pd.NA), - ]) + self.assertEqual(rows, [(1, 2), (1, pd.NA), (pd.NA, pd.NA),]) else: - self.assertEqual(rows, [ - (1, 2), - (1, np.nan), - (np.nan, np.nan), - ]) + self.assertEqual(rows, [(1, 2), (1, np.nan), (np.nan, np.nan),]) @with_pandas_cursor() def test_boolean_na_values(self, cursor): - df = cursor.execute(""" + df = cursor.execute( + """ SELECT * FROM boolean_na_values - """).as_pandas() - rows = [tuple([ - row['a'], - row['b'], - ]) for _, row in df.iterrows()] - self.assertEqual(rows, [ - (True, False), - (False, None), - (None, None), - ]) + """ + ).as_pandas() + rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] + self.assertEqual(rows, [(True, False), (False, None), (None, None),]) @with_pandas_cursor() def test_executemany(self, cursor): cursor.executemany( - 'INSERT INTO execute_many_pandas (a) VALUES (%(a)s)', - [{'a': i} for i in xrange(1, 3)] + "INSERT INTO execute_many_pandas (a) VALUES (%(a)s)", + [{"a": i} for i in xrange(1, 3)], ) - cursor.execute('SELECT * FROM execute_many_pandas') + cursor.execute("SELECT * FROM execute_many_pandas") self.assertEqual(sorted(cursor.fetchall()), [(i,) for i in xrange(1, 3)]) @with_pandas_cursor() def test_executemany_fetch(self, cursor): - cursor.executemany( - 'SELECT %(x)d FROM one_row', - [{'x': i} for i in range(1, 2)] - ) + cursor.executemany("SELECT %(x)d FROM one_row", [{"x": i} for i in range(1, 2)]) # Operations that have result sets are not allowed with executemany. self.assertRaises(ProgrammingError, cursor.fetchall) self.assertRaises(ProgrammingError, cursor.fetchmany) @@ -389,7 +407,9 @@ def test_executemany_fetch(self, cursor): @with_pandas_cursor() def test_not_skip_blank_lines(self, cursor): - cursor.execute(""" + cursor.execute( + """ select * from (values (1), (NULL)) - """) + """ + ) self.assertEqual(len(cursor.fetchall()), 2) diff --git a/tests/test_sqlalchemy_athena.py b/tests/test_sqlalchemy_athena.py index 1f0d542f..d436d5ed 100644 --- a/tests/test_sqlalchemy_athena.py +++ b/tests/test_sqlalchemy_athena.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import re import unittest @@ -15,8 +14,17 @@ from sqlalchemy.exc import NoSuchTableError, OperationalError, ProgrammingError from sqlalchemy.sql import expression from sqlalchemy.sql.schema import Column, MetaData, Table -from sqlalchemy.sql.sqltypes import (BIGINT, BINARY, BOOLEAN, DATE, DECIMAL, - FLOAT, INTEGER, STRINGTYPE, TIMESTAMP) +from sqlalchemy.sql.sqltypes import ( + BIGINT, + BINARY, + BOOLEAN, + DATE, + DECIMAL, + FLOAT, + INTEGER, + STRINGTYPE, + TIMESTAMP, +) from tests.conftest import ENV, SCHEMA from tests.util import with_engine @@ -36,17 +44,23 @@ class TestSQLAlchemyAthena(unittest.TestCase): """ def create_engine(self): - conn_str = 'awsathena+rest://athena.{region_name}.amazonaws.com:443/' + \ - '{schema_name}?s3_staging_dir={s3_staging_dir}&s3_dir={s3_dir}' + \ - '&compression=snappy' - return create_engine(conn_str.format(region_name=ENV.region_name, - schema_name=SCHEMA, - s3_staging_dir=quote_plus(ENV.s3_staging_dir), - s3_dir=quote_plus(ENV.s3_staging_dir))) + conn_str = ( + "awsathena+rest://athena.{region_name}.amazonaws.com:443/" + + 
"{schema_name}?s3_staging_dir={s3_staging_dir}&s3_dir={s3_dir}" + + "&compression=snappy" + ) + return create_engine( + conn_str.format( + region_name=ENV.region_name, + schema_name=SCHEMA, + s3_staging_dir=quote_plus(ENV.s3_staging_dir), + s3_dir=quote_plus(ENV.s3_staging_dir), + ) + ) @with_engine() def test_basic_query(self, engine, conn): - rows = conn.execute('SELECT * FROM one_row').fetchall() + rows = conn.execute("SELECT * FROM one_row").fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0].number_of_rows, 1) self.assertEqual(len(rows[0]), 1) @@ -55,51 +69,60 @@ def test_basic_query(self, engine, conn): def test_reflect_no_such_table(self, engine, conn): self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True)) + lambda: Table("this_does_not_exist", MetaData(bind=engine), autoload=True), + ) self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), - schema='also_does_not_exist', autoload=True)) + lambda: Table( + "this_does_not_exist", + MetaData(bind=engine), + schema="also_does_not_exist", + autoload=True, + ), + ) @with_engine() def test_reflect_table(self, engine, conn): - one_row = Table('one_row', MetaData(bind=engine), autoload=True) + one_row = Table("one_row", MetaData(bind=engine), autoload=True) self.assertEqual(len(one_row.c), 1) self.assertIsNotNone(one_row.c.number_of_rows) @with_engine() def test_reflect_table_with_schema(self, engine, conn): - one_row = Table('one_row', MetaData(bind=engine), - schema=SCHEMA, autoload=True) + one_row = Table("one_row", MetaData(bind=engine), schema=SCHEMA, autoload=True) self.assertEqual(len(one_row.c), 1) self.assertIsNotNone(one_row.c.number_of_rows) @with_engine() def test_reflect_table_include_columns(self, engine, conn): - one_row_complex = Table('one_row_complex', MetaData(bind=engine)) - version = float(re.search(r'^([\d]+\.[\d]+)\..+', sqlalchemy.__version__).group(1)) + one_row_complex = Table("one_row_complex", MetaData(bind=engine)) + version = float( + re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1) + ) if version <= 1.2: - engine.dialect.reflecttable(conn, one_row_complex, - include_columns=['col_int'], - exclude_columns=[]) + engine.dialect.reflecttable( + conn, one_row_complex, include_columns=["col_int"], exclude_columns=[] + ) else: # https://docs.sqlalchemy.org/en/13/changelog/changelog_13.html# # change-64ac776996da1a5c3e3460b4c0f0b257 - engine.dialect.reflecttable(conn, one_row_complex, - include_columns=['col_int'], - exclude_columns=[], - resolve_fks=True) + engine.dialect.reflecttable( + conn, + one_row_complex, + include_columns=["col_int"], + exclude_columns=[], + resolve_fks=True, + ) self.assertEqual(len(one_row_complex.c), 1) self.assertIsNotNone(one_row_complex.c.col_int) self.assertRaises(AttributeError, lambda: one_row_complex.c.col_tinyint) @with_engine() def test_unicode(self, engine, conn): - unicode_str = '密林' - one_row = Table('one_row', MetaData(bind=engine)) + unicode_str = "密林" + one_row = Table("one_row", MetaData(bind=engine)) returned_str = sqlalchemy.select( - [expression.bindparam('あまぞん', unicode_str)], - from_obj=one_row, + [expression.bindparam("あまぞん", unicode_str)], from_obj=one_row, ).scalar() self.assertEqual(returned_str, unicode_str) @@ -108,70 +131,78 @@ def test_reflect_schemas(self, engine, conn): insp = sqlalchemy.inspect(engine) schemas = insp.get_schema_names() self.assertIn(SCHEMA, schemas) - self.assertIn('default', schemas) + 
self.assertIn("default", schemas) @with_engine() def test_get_table_names(self, engine, conn): meta = MetaData() meta.reflect(bind=engine) print(meta.tables) - self.assertIn('one_row', meta.tables) - self.assertIn('one_row_complex', meta.tables) + self.assertIn("one_row", meta.tables) + self.assertIn("one_row_complex", meta.tables) insp = sqlalchemy.inspect(engine) self.assertIn( - 'many_rows', - insp.get_table_names(schema=SCHEMA), + "many_rows", insp.get_table_names(schema=SCHEMA), ) @with_engine() def test_has_table(self, engine, conn): - self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) - self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + self.assertTrue(Table("one_row", MetaData(bind=engine)).exists()) + self.assertFalse( + Table("this_table_does_not_exist", MetaData(bind=engine)).exists() + ) @with_engine() def test_get_columns(self, engine, conn): insp = sqlalchemy.inspect(engine) - actual = insp.get_columns(table_name='one_row', schema=SCHEMA)[0] - self.assertEqual(actual['name'], 'number_of_rows') - self.assertTrue(isinstance(actual['type'], INTEGER)) - self.assertTrue(actual['nullable']) - self.assertIsNone(actual['default']) - self.assertEqual(actual['ordinal_position'], 1) - self.assertIsNone(actual['comment']) + actual = insp.get_columns(table_name="one_row", schema=SCHEMA)[0] + self.assertEqual(actual["name"], "number_of_rows") + self.assertTrue(isinstance(actual["type"], INTEGER)) + self.assertTrue(actual["nullable"]) + self.assertIsNone(actual["default"]) + self.assertEqual(actual["ordinal_position"], 1) + self.assertIsNone(actual["comment"]) @with_engine() def test_char_length(self, engine, conn): - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) - result = sqlalchemy.select([ - sqlalchemy.func.char_length(one_row_complex.c.col_string) - ]).execute().scalar() - self.assertEqual(result, len('a string')) + one_row_complex = Table("one_row_complex", MetaData(bind=engine), autoload=True) + result = ( + sqlalchemy.select( + [sqlalchemy.func.char_length(one_row_complex.c.col_string)] + ) + .execute() + .scalar() + ) + self.assertEqual(result, len("a string")) @with_engine() def test_reflect_select(self, engine, conn): - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table("one_row_complex", MetaData(bind=engine), autoload=True) self.assertEqual(len(one_row_complex.c), 15) self.assertIsInstance(one_row_complex.c.col_string, Column) rows = one_row_complex.select().execute().fetchall() self.assertEqual(len(rows), 1) - self.assertEqual(list(rows[0]), [ - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - datetime(2017, 1, 1, 0, 0, 0), - date(2017, 1, 2), - b'123', - '[1, 2]', - '{1=2, 3=4}', - '{a=1, b=2}', - Decimal('0.1'), - ]) + self.assertEqual( + list(rows[0]), + [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + datetime(2017, 1, 1, 0, 0, 0), + date(2017, 1, 2), + b"123", + "[1, 2]", + "{1=2, 3=4}", + "{a=1, b=2}", + Decimal("0.1"), + ], + ) self.assertIsInstance(one_row_complex.c.col_boolean.type, BOOLEAN) self.assertIsInstance(one_row_complex.c.col_tinyint.type, INTEGER) self.assertIsInstance(one_row_complex.c.col_smallint.type, INTEGER) @@ -191,128 +222,196 @@ def test_reflect_select(self, engine, conn): @with_engine() def test_reserved_words(self, engine, conn): """Presto uses double quotes, not backticks""" - fake_table = Table('select', 
MetaData(bind=engine), Column('current_timestamp', STRINGTYPE)) - query = str(fake_table.select(fake_table.c.current_timestamp == 'a')) + fake_table = Table( + "select", MetaData(bind=engine), Column("current_timestamp", STRINGTYPE) + ) + query = str(fake_table.select(fake_table.c.current_timestamp == "a")) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) - self.assertNotIn('`select`', query) - self.assertNotIn('`current_timestamp`', query) + self.assertNotIn("`select`", query) + self.assertNotIn("`current_timestamp`", query) @with_engine() def test_retry_if_data_catalog_exception(self, engine, conn): dialect = engine.dialect - exc = OperationalError('', None, - 'Database does_not_exist not found. Please check your query.') - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'this_does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'this_does_not_exist')) - - exc = OperationalError('', None, - 'Namespace does_not_exist not found. Please check your query.') - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'this_does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'this_does_not_exist')) - - exc = OperationalError('', None, - 'Table does_not_exist not found. Please check your query.') - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'this_does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'does_not_exist')) - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'this_does_not_exist')) - - exc = OperationalError('', None, - 'foobar.') - self.assertTrue(dialect._retry_if_data_catalog_exception( - exc, 'foobar', 'foobar')) - - exc = ProgrammingError('', None, - 'Database does_not_exist not found. Please check your query.') - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'does_not_exist', 'this_does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'does_not_exist')) - self.assertFalse(dialect._retry_if_data_catalog_exception( - exc, 'this_does_not_exist', 'this_does_not_exist')) + exc = OperationalError( + "", None, "Database does_not_exist not found. Please check your query." 
+ ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "this_does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "this_does_not_exist" + ) + ) + + exc = OperationalError( + "", None, "Namespace does_not_exist not found. Please check your query." + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "this_does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "this_does_not_exist" + ) + ) + + exc = OperationalError( + "", None, "Table does_not_exist not found. Please check your query." + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "this_does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "does_not_exist" + ) + ) + self.assertTrue( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "this_does_not_exist" + ) + ) + + exc = OperationalError("", None, "foobar.") + self.assertTrue( + dialect._retry_if_data_catalog_exception(exc, "foobar", "foobar") + ) + + exc = ProgrammingError( + "", None, "Database does_not_exist not found. Please check your query." 
+ ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "does_not_exist", "this_does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "does_not_exist" + ) + ) + self.assertFalse( + dialect._retry_if_data_catalog_exception( + exc, "this_does_not_exist", "this_does_not_exist" + ) + ) @with_engine() def test_get_column_type(self, engine, conn): dialect = engine.dialect - self.assertEqual(dialect._get_column_type('boolean'), 'boolean') - self.assertEqual(dialect._get_column_type('tinyint'), 'tinyint') - self.assertEqual(dialect._get_column_type('smallint'), 'smallint') - self.assertEqual(dialect._get_column_type('integer'), 'integer') - self.assertEqual(dialect._get_column_type('bigint'), 'bigint') - self.assertEqual(dialect._get_column_type('real'), 'real') - self.assertEqual(dialect._get_column_type('double'), 'double') - self.assertEqual(dialect._get_column_type('varchar'), 'varchar') - self.assertEqual(dialect._get_column_type('timestamp'), 'timestamp') - self.assertEqual(dialect._get_column_type('date'), 'date') - self.assertEqual(dialect._get_column_type('varbinary'), 'varbinary') - self.assertEqual(dialect._get_column_type('array(integer)'), 'array') - self.assertEqual(dialect._get_column_type('map(integer, integer)'), 'map') - self.assertEqual(dialect._get_column_type('row(a integer, b integer)'), 'row') - self.assertEqual(dialect._get_column_type('decimal(10,1)'), 'decimal') + self.assertEqual(dialect._get_column_type("boolean"), "boolean") + self.assertEqual(dialect._get_column_type("tinyint"), "tinyint") + self.assertEqual(dialect._get_column_type("smallint"), "smallint") + self.assertEqual(dialect._get_column_type("integer"), "integer") + self.assertEqual(dialect._get_column_type("bigint"), "bigint") + self.assertEqual(dialect._get_column_type("real"), "real") + self.assertEqual(dialect._get_column_type("double"), "double") + self.assertEqual(dialect._get_column_type("varchar"), "varchar") + self.assertEqual(dialect._get_column_type("timestamp"), "timestamp") + self.assertEqual(dialect._get_column_type("date"), "date") + self.assertEqual(dialect._get_column_type("varbinary"), "varbinary") + self.assertEqual(dialect._get_column_type("array(integer)"), "array") + self.assertEqual(dialect._get_column_type("map(integer, integer)"), "map") + self.assertEqual(dialect._get_column_type("row(a integer, b integer)"), "row") + self.assertEqual(dialect._get_column_type("decimal(10,1)"), "decimal") @with_engine() def test_contain_percents_character_query(self, engine, conn): - query = sqlalchemy.sql.text(""" + query = sqlalchemy.sql.text( + """ SELECT date_parse('20191030', '%Y%m%d') - """) + """ + ) result = engine.execute(query) - self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), )]) + self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30),)]) @with_engine() def test_query_with_parameter(self, engine, conn): - query = sqlalchemy.sql.text(""" + query = sqlalchemy.sql.text( + """ SELECT :word - """) - result = engine.execute(query, word='cat') - self.assertEqual(result.fetchall(), [('cat', )]) + """ + ) + result = engine.execute(query, word="cat") + self.assertEqual(result.fetchall(), [("cat",)]) @with_engine() def test_contain_percents_character_query_with_parameter(self, engine, conn): - query = sqlalchemy.sql.text(""" + query = sqlalchemy.sql.text( + """ SELECT 
date_parse('20191030', '%Y%m%d'), :word - """) - result = engine.execute(query, word='cat') - self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), 'cat')]) + """ + ) + result = engine.execute(query, word="cat") + self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), "cat")]) - query = sqlalchemy.sql.text(""" + query = sqlalchemy.sql.text( + """ SELECT col_string FROM one_row_complex WHERE col_string LIKE 'a%' OR col_string LIKE :param - """) - result = engine.execute(query, param='b%') - self.assertEqual(result.fetchall(), [('a string', )]) + """ + ) + result = engine.execute(query, param="b%") + self.assertEqual(result.fetchall(), [("a string",)]) @with_engine() def test_nan_checks(self, engine, conn): dialect = engine.dialect self.assertFalse(dialect._is_nan("string")) self.assertFalse(dialect._is_nan(1)) - self.assertTrue(dialect._is_nan(float('nan'))) + self.assertTrue(dialect._is_nan(float("nan"))) @with_engine() def test_to_sql(self, engine, conn): - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - df = pd.DataFrame({'a': [1, 2, 3, 4, 5]}) - df.to_sql(table_name, engine, schema=SCHEMA, index=False, - if_exists='replace', method='multi') + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) + df.to_sql( + table_name, + engine, + schema=SCHEMA, + index=False, + if_exists="replace", + method="multi", + ) table = Table(table_name, MetaData(bind=engine), autoload=True) rows = table.select().execute().fetchall() diff --git a/tests/test_util.py b/tests/test_util.py index 31cc6d0f..2bbb2a8c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,11 +1,10 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import textwrap import unittest import uuid -from datetime import datetime, date +from datetime import date, datetime from decimal import Decimal import numpy as np @@ -13,26 +12,31 @@ from past.builtins import xrange from pyathena import DataError, OperationalError -from pyathena.util import (as_pandas, generate_ddl, to_sql, get_chunks, - parse_output_location, reset_index) -from tests import WithConnect, SCHEMA, ENV, S3_PREFIX +from pyathena.util import ( + as_pandas, + generate_ddl, + get_chunks, + parse_output_location, + reset_index, + to_sql, +) +from tests import ENV, S3_PREFIX, SCHEMA, WithConnect from tests.util import with_cursor class TestUtil(unittest.TestCase, WithConnect): - def test_parse_output_location(self): # valid - actual = parse_output_location('s3://bucket/path/to') - self.assertEqual(actual[0], 'bucket') - self.assertEqual(actual[1], 'path/to') + actual = parse_output_location("s3://bucket/path/to") + self.assertEqual(actual[0], "bucket") + self.assertEqual(actual[1], "path/to") # invalid with self.assertRaises(DataError): - parse_output_location('http://foobar') + parse_output_location("http://foobar") def test_get_chunks(self): - df = pd.DataFrame({'a': [1, 2, 3, 4, 5]}) + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) actual1 = get_chunks(df) self.assertEqual([len(a) for a in actual1], [5]) actual2 = get_chunks(df, chunksize=2) @@ -50,21 +54,22 @@ def test_get_chunks(self): list(get_chunks(df, chunksize=-1)) def test_reset_index(self): - df = pd.DataFrame({'a': [1, 2, 3, 4, 5]}) + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) reset_index(df) - self.assertEqual(list(df.columns), ['index', 'a']) + self.assertEqual(list(df.columns), ["index", "a"]) - df = 
pd.DataFrame({'a': [1, 2, 3, 4, 5]}) - reset_index(df, index_label='__index__') - self.assertEqual(list(df.columns), ['__index__', 'a']) + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) + reset_index(df, index_label="__index__") + self.assertEqual(list(df.columns), ["__index__", "a"]) - df = pd.DataFrame({'a': [1, 2, 3, 4, 5]}) + df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) with self.assertRaises(ValueError): - reset_index(df, index_label='a') + reset_index(df, index_label="a") @with_cursor() def test_as_pandas(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT col_boolean ,col_tinyint @@ -85,60 +90,67 @@ def test_as_pandas(self, cursor): ,col_struct ,col_decimal FROM one_row_complex - """) + """ + ) df = as_pandas(cursor) - rows = [tuple([ - row['col_boolean'], - row['col_tinyint'], - row['col_smallint'], - row['col_int'], - row['col_bigint'], - row['col_float'], - row['col_double'], - row['col_string'], - row['col_timestamp'], - row['col_time'], - row['col_date'], - row['col_binary'], - row['col_array'], - row['col_array_json'], - row['col_map'], - row['col_map_json'], - row['col_struct'], - row['col_decimal'], - ]) for _, row in df.iterrows()] - expected = [( - True, - 127, - 32767, - 2147483647, - 9223372036854775807, - 0.5, - 0.25, - 'a string', - datetime(2017, 1, 1, 0, 0, 0), - datetime(2017, 1, 1, 0, 0, 0).time(), - date(2017, 1, 2), - b'123', - '[1, 2]', - [1, 2], - '{1=2, 3=4}', - {'1': 2, '3': 4}, - '{a=1, b=2}', - Decimal('0.1'), - )] + rows = [ + tuple( + [ + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_timestamp"], + row["col_time"], + row["col_date"], + row["col_binary"], + row["col_array"], + row["col_array_json"], + row["col_map"], + row["col_map_json"], + row["col_struct"], + row["col_decimal"], + ] + ) + for _, row in df.iterrows() + ] + expected = [ + ( + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + "a string", + datetime(2017, 1, 1, 0, 0, 0), + datetime(2017, 1, 1, 0, 0, 0).time(), + date(2017, 1, 2), + b"123", + "[1, 2]", + [1, 2], + "{1=2, 3=4}", + {"1": 2, "3": 4}, + "{a=1, b=2}", + Decimal("0.1"), + ) + ] self.assertEqual(rows, expected) @with_cursor() def test_as_pandas_integer_na_values(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT * FROM integer_na_values - """) + """ + ) df = as_pandas(cursor, coerce_float=True) - rows = [tuple([ - row['a'], - row['b'], - ]) for _, row in df.iterrows()] + rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] # TODO AssertionError: Lists differ: # [(1.0, 2.0), (1.0, nan), (nan, nan)] != [(1.0, 2.0), (1.0, nan), (nan, nan)] # self.assertEqual(rows, [ @@ -146,48 +158,54 @@ def test_as_pandas_integer_na_values(self, cursor): # (1.0, np.nan), # (np.nan, np.nan), # ]) - np.testing.assert_array_equal(rows, [ - (1, 2), - (1, np.nan), - (np.nan, np.nan), - ]) + np.testing.assert_array_equal(rows, [(1, 2), (1, np.nan), (np.nan, np.nan),]) @with_cursor() def test_as_pandas_boolean_na_values(self, cursor): - cursor.execute(""" + cursor.execute( + """ SELECT * FROM boolean_na_values - """) + """ + ) df = as_pandas(cursor) - rows = [tuple([ - row['a'], - row['b'], - ]) for _, row in df.iterrows()] - self.assertEqual(rows, [ - (True, False), - (False, None), - (None, None), - ]) + rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] + self.assertEqual(rows, [(True, False), (False, None), (None, None),]) def 
test_generate_ddl(self): # TODO Add binary column (After dropping support for Python 2.7) - df = pd.DataFrame({ - 'col_int': np.int32([1]), - 'col_bigint': np.int64([12345]), - 'col_float': np.float32([1.0]), - 'col_double': np.float64([1.2345]), - 'col_string': ['a'], - 'col_boolean': np.bool_([True]), - 'col_timestamp': [datetime(2020, 1, 1, 0, 0, 0)], - 'col_date': [date(2020, 12, 31)], - 'col_timedelta': [np.timedelta64(1, 'D')], - }) + df = pd.DataFrame( + { + "col_int": np.int32([1]), + "col_bigint": np.int64([12345]), + "col_float": np.float32([1.0]), + "col_double": np.float64([1.2345]), + "col_string": ["a"], + "col_boolean": np.bool_([True]), + "col_timestamp": [datetime(2020, 1, 1, 0, 0, 0)], + "col_date": [date(2020, 12, 31)], + "col_timedelta": [np.timedelta64(1, "D")], + } + ) # Explicitly specify column order - df = df[['col_int', 'col_bigint', 'col_float', 'col_double', 'col_string', - 'col_boolean', 'col_timestamp', 'col_date', 'col_timedelta']] + df = df[ + [ + "col_int", + "col_bigint", + "col_float", + "col_double", + "col_string", + "col_boolean", + "col_timestamp", + "col_date", + "col_timedelta", + ] + ] - actual = generate_ddl(df, 'test_table', 's3://bucket/path/to/', 'test_schema') - self.assertEqual(actual.strip(), textwrap.dedent( - """ + actual = generate_ddl(df, "test_table", "s3://bucket/path/to/", "test_schema") + self.assertEqual( + actual.strip(), + textwrap.dedent( + """ CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( `col_int` INT, `col_bigint` BIGINT, @@ -201,13 +219,22 @@ def test_generate_ddl(self): ) STORED AS PARQUET LOCATION 's3://bucket/path/to/' - """).strip()) + """ + ).strip(), + ) # compression - actual = generate_ddl(df, 'test_table', 's3://bucket/path/to/', 'test_schema', - compression='snappy') - self.assertEqual(actual.strip(), textwrap.dedent( - """ + actual = generate_ddl( + df, + "test_table", + "s3://bucket/path/to/", + "test_schema", + compression="snappy", + ) + self.assertEqual( + actual.strip(), + textwrap.dedent( + """ CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( `col_int` INT, `col_bigint` BIGINT, @@ -222,13 +249,22 @@ def test_generate_ddl(self): STORED AS PARQUET LOCATION 's3://bucket/path/to/' TBLPROPERTIES ('parquet.compress'='SNAPPY') - """).strip()) + """ + ).strip(), + ) # partitions - actual = generate_ddl(df, 'test_table', 's3://bucket/path/to/', 'test_schema', - partitions=['col_int']) - self.assertEqual(actual.strip(), textwrap.dedent( - """ + actual = generate_ddl( + df, + "test_table", + "s3://bucket/path/to/", + "test_schema", + partitions=["col_int"], + ) + self.assertEqual( + actual.strip(), + textwrap.dedent( + """ CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( `col_bigint` BIGINT, `col_float` FLOAT, @@ -244,13 +280,22 @@ def test_generate_ddl(self): ) STORED AS PARQUET LOCATION 's3://bucket/path/to/' - """).strip()) + """ + ).strip(), + ) # multiple partitions - actual = generate_ddl(df, 'test_table', 's3://bucket/path/to/', 'test_schema', - partitions=['col_int', 'col_string']) - self.assertEqual(actual.strip(), textwrap.dedent( - """ + actual = generate_ddl( + df, + "test_table", + "s3://bucket/path/to/", + "test_schema", + partitions=["col_int", "col_string"], + ) + self.assertEqual( + actual.strip(), + textwrap.dedent( + """ CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( `col_bigint` BIGINT, `col_float` FLOAT, @@ -266,153 +311,255 @@ def test_generate_ddl(self): ) STORED AS PARQUET LOCATION 's3://bucket/path/to/' - """).strip()) + """ + 
).strip(), + ) # complex - df = pd.DataFrame({'col_complex': np.complex_([1.0, 2.0, 3.0, 4.0, 5.0])}) + df = pd.DataFrame({"col_complex": np.complex_([1.0, 2.0, 3.0, 4.0, 5.0])}) with self.assertRaises(ValueError): - generate_ddl(df, 'test_table', 's3://bucket/path/to/') + generate_ddl(df, "test_table", "s3://bucket/path/to/") # time - df = pd.DataFrame({'col_time': [datetime(2020, 1, 1, 0, 0, 0).time()]}, - index=['i']) + df = pd.DataFrame( + {"col_time": [datetime(2020, 1, 1, 0, 0, 0).time()]}, index=["i"] + ) with self.assertRaises(ValueError): - generate_ddl(df, 'test_table', 's3://bucket/path/to/') + generate_ddl(df, "test_table", "s3://bucket/path/to/") @with_cursor() def test_to_sql(self, cursor): # TODO Add binary column (After dropping support for Python 2.7) - df = pd.DataFrame({ - 'col_int': np.int32([1]), - 'col_bigint': np.int64([12345]), - 'col_float': np.float32([1.0]), - 'col_double': np.float64([1.2345]), - 'col_string': ['a'], - 'col_boolean': np.bool_([True]), - 'col_timestamp': [datetime(2020, 1, 1, 0, 0, 0)], - 'col_date': [date(2020, 12, 31)], - }) + df = pd.DataFrame( + { + "col_int": np.int32([1]), + "col_bigint": np.int64([12345]), + "col_float": np.float32([1.0]), + "col_double": np.float64([1.2345]), + "col_string": ["a"], + "col_boolean": np.bool_([True]), + "col_timestamp": [datetime(2020, 1, 1, 0, 0, 0)], + "col_date": [date(2020, 12, 31)], + } + ) # Explicitly specify column order - df = df[['col_int', 'col_bigint', 'col_float', 'col_double', 'col_string', - 'col_boolean', 'col_timestamp', 'col_date']] - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table_name) - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='fail', compression='snappy') + df = df[ + [ + "col_int", + "col_bigint", + "col_float", + "col_double", + "col_string", + "col_boolean", + "col_timestamp", + "col_date", + ] + ] + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table_name) + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="fail", + compression="snappy", + ) # table already exists with self.assertRaises(OperationalError): - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='fail', compression='snappy') + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="fail", + compression="snappy", + ) # replace - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='replace', compression='snappy') + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="replace", + compression="snappy", + ) - cursor.execute('SELECT * FROM {0}'.format(table_name)) - self.assertEqual(cursor.fetchall(), [( - 1, - 12345, - 1.0, - 1.2345, - 'a', - True, - datetime(2020, 1, 1, 0, 0, 0), - date(2020, 12, 31), - )]) - self.assertEqual([(d[0], d[1]) for d in cursor.description], [ - ('col_int', 'integer'), - ('col_bigint', 'bigint'), - ('col_float', 'float'), - ('col_double', 'double'), - ('col_string', 'varchar'), - ('col_boolean', 'boolean'), - ('col_timestamp', 'timestamp'), - ('col_date', 'date'), - ]) + cursor.execute("SELECT * FROM {0}".format(table_name)) + self.assertEqual( + cursor.fetchall(), + [ + ( + 1, + 12345, + 1.0, + 1.2345, + "a", + True, + datetime(2020, 1, 1, 0, 0, 0), + date(2020, 12, 31), + ) + ], + ) + self.assertEqual( + 
[(d[0], d[1]) for d in cursor.description], + [ + ("col_int", "integer"), + ("col_bigint", "bigint"), + ("col_float", "float"), + ("col_double", "double"), + ("col_string", "varchar"), + ("col_boolean", "boolean"), + ("col_timestamp", "timestamp"), + ("col_date", "date"), + ], + ) # append - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='append', compression='snappy') - cursor.execute('SELECT * FROM {0}'.format(table_name)) - self.assertEqual(cursor.fetchall(), [( - 1, - 12345, - 1.0, - 1.2345, - 'a', - True, - datetime(2020, 1, 1, 0, 0, 0), - date(2020, 12, 31), - ), ( - 1, - 12345, - 1.0, - 1.2345, - 'a', - True, - datetime(2020, 1, 1, 0, 0, 0), - date(2020, 12, 31), - )]) + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="append", + compression="snappy", + ) + cursor.execute("SELECT * FROM {0}".format(table_name)) + self.assertEqual( + cursor.fetchall(), + [ + ( + 1, + 12345, + 1.0, + 1.2345, + "a", + True, + datetime(2020, 1, 1, 0, 0, 0), + date(2020, 12, 31), + ), + ( + 1, + 12345, + 1.0, + 1.2345, + "a", + True, + datetime(2020, 1, 1, 0, 0, 0), + date(2020, 12, 31), + ), + ], + ) @with_cursor() def test_to_sql_with_index(self, cursor): - df = pd.DataFrame({'col_int': np.int32([1])}) - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table_name) - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='fail', compression='snappy', - index=True, index_label='col_index') - cursor.execute('SELECT * FROM {0}'.format(table_name)) + df = pd.DataFrame({"col_int": np.int32([1])}) + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table_name) + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="fail", + compression="snappy", + index=True, + index_label="col_index", + ) + cursor.execute("SELECT * FROM {0}".format(table_name)) self.assertEqual(cursor.fetchall(), [(0, 1)]) - self.assertEqual([(d[0], d[1]) for d in cursor.description], [ - ('col_index', 'bigint'), - ('col_int', 'integer'), - ]) + self.assertEqual( + [(d[0], d[1]) for d in cursor.description], + [("col_index", "bigint"), ("col_int", "integer"),], + ) @with_cursor() def test_to_sql_with_partitions(self, cursor): - df = pd.DataFrame({ - 'col_int': np.int32([i for i in xrange(10)]), - 'col_bigint': np.int64([12345 for _ in xrange(10)]), - 'col_string': ['a' for _ in xrange(10)], - }) - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table_name) - to_sql(df, table_name, cursor._connection, location, schema=SCHEMA, - partitions=['col_int'], if_exists='fail', compression='snappy') - cursor.execute('SHOW PARTITIONS {0}'.format(table_name)) - self.assertEqual(sorted(cursor.fetchall()), - [('col_int={0}'.format(i),) for i in xrange(10)]) - cursor.execute('SELECT COUNT(*) FROM {0}'.format(table_name)) - self.assertEqual(cursor.fetchall(), [(10, ), ]) + df = pd.DataFrame( + { + "col_int": np.int32([i for i in xrange(10)]), + "col_bigint": np.int64([12345 for _ in xrange(10)]), + "col_string": ["a" for _ in xrange(10)], + } + ) + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table_name) + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, 
+ partitions=["col_int"], + if_exists="fail", + compression="snappy", + ) + cursor.execute("SHOW PARTITIONS {0}".format(table_name)) + self.assertEqual( + sorted(cursor.fetchall()), [("col_int={0}".format(i),) for i in xrange(10)] + ) + cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name)) + self.assertEqual(cursor.fetchall(), [(10,),]) @with_cursor() def test_to_sql_with_multiple_partitions(self, cursor): - df = pd.DataFrame({ - 'col_int': np.int32([i for i in xrange(10)]), - 'col_bigint': np.int64([12345 for _ in xrange(10)]), - 'col_string': ['a' for _ in xrange(5)] + ['b' for _ in xrange(5)], - }) - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table_name) - to_sql(df, table_name, cursor._connection, location, schema=SCHEMA, - partitions=['col_int', 'col_string'], if_exists='fail', compression='snappy') - cursor.execute('SHOW PARTITIONS {0}'.format(table_name)) - self.assertEqual(sorted(cursor.fetchall()), - [('col_int={0}/col_string=a'.format(i),) for i in xrange(5)] + - [('col_int={0}/col_string=b'.format(i),) for i in xrange(5, 10)]) - cursor.execute('SELECT COUNT(*) FROM {0}'.format(table_name)) - self.assertEqual(cursor.fetchall(), [(10, ), ]) + df = pd.DataFrame( + { + "col_int": np.int32([i for i in xrange(10)]), + "col_bigint": np.int64([12345 for _ in xrange(10)]), + "col_string": ["a" for _ in xrange(5)] + ["b" for _ in xrange(5)], + } + ) + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table_name) + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + partitions=["col_int", "col_string"], + if_exists="fail", + compression="snappy", + ) + cursor.execute("SHOW PARTITIONS {0}".format(table_name)) + self.assertEqual( + sorted(cursor.fetchall()), + [("col_int={0}/col_string=a".format(i),) for i in xrange(5)] + + [("col_int={0}/col_string=b".format(i),) for i in xrange(5, 10)], + ) + cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name)) + self.assertEqual(cursor.fetchall(), [(10,),]) @with_cursor() def test_to_sql_invalid_args(self, cursor): - df = pd.DataFrame({'col_int': np.int32([1])}) - table_name = 'to_sql_{0}'.format(str(uuid.uuid4()).replace('-', '')) - location = '{0}{1}/{2}/'.format(ENV.s3_staging_dir, S3_PREFIX, table_name) + df = pd.DataFrame({"col_int": np.int32([1])}) + table_name = "to_sql_{0}".format(str(uuid.uuid4()).replace("-", "")) + location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table_name) # invalid if_exists with self.assertRaises(ValueError): - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='foobar', compression='snappy') + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="foobar", + compression="snappy", + ) # invalid compression with self.assertRaises(ValueError): - to_sql(df, table_name, cursor._connection, location, - schema=SCHEMA, if_exists='fail', compression='foobar') + to_sql( + df, + table_name, + cursor._connection, + location, + schema=SCHEMA, + if_exists="fail", + compression="foobar", + ) diff --git a/tests/util.py b/tests/util.py index c5faaa4f..d045c053 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import -from __future__ import unicode_literals +from __future__ import absolute_import, unicode_literals import codecs import contextlib @@ -14,7 +13,9 @@ def 
wrapped_fn(self, *args, **kwargs): with contextlib.closing(self.connect(work_group=work_group)) as conn: with conn.cursor() as cursor: fn(self, cursor, *args, **kwargs) + return wrapped_fn + return _with_cursor @@ -27,7 +28,9 @@ def wrapped_fn(self, *args, **kwargs): with contextlib.closing(self.connect()) as conn: with conn.cursor(AsyncCursor) as cursor: fn(self, cursor, *args, **kwargs) + return wrapped_fn + return _with_async_cursor @@ -40,7 +43,9 @@ def wrapped_fn(self, *args, **kwargs): with contextlib.closing(self.connect()) as conn: with conn.cursor(PandasCursor) as cursor: fn(self, cursor, *args, **kwargs) + return wrapped_fn + return _with_pandas_cursor @@ -53,7 +58,9 @@ def wrapped_fn(self, *args, **kwargs): with contextlib.closing(self.connect()) as conn: with conn.cursor(AsyncPandasCursor) as cursor: fn(self, cursor, *args, **kwargs) + return wrapped_fn + return _with_async_pandas_cursor @@ -67,11 +74,13 @@ def wrapped_fn(self, *args, **kwargs): fn(self, engine, conn, *args, **kwargs) finally: engine.dispose() + return wrapped_fn + return _with_engine def read_query(path): - with codecs.open(path, 'rb', 'utf-8') as f: + with codecs.open(path, "rb", "utf-8") as f: query = f.read() - return [q.strip() for q in query.split(';') if q and q.strip()] + return [q.strip() for q in query.split(";") if q and q.strip()] From f44e2b8b3b72f24b45a68353ca3ad074e36f1734 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 16:50:08 +0900 Subject: [PATCH 3/6] Fix E231 missing whitespace after ',' --- pyathena/common.py | 6 ++---- setup.py | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyathena/common.py b/pyathena/common.py index 944b2b06..f1bf99d2 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -134,7 +134,7 @@ def _build_start_query_execution_request( ): request = { "QueryString": query, - "QueryExecutionContext": {"Database": self._schema_name,}, + "QueryExecutionContext": {"Database": self._schema_name}, "ResultConfiguration": {}, } if self._s3_staging_dir or s3_staging_dir: @@ -155,9 +155,7 @@ def _build_start_query_execution_request( } if self._kms_key: enc_conf.update({"KmsKey": self._kms_key}) - request["ResultConfiguration"].update( - {"EncryptionConfiguration": enc_conf,} - ) + request["ResultConfiguration"].update({"EncryptionConfiguration": enc_conf}) return request def _build_list_query_executions_request( diff --git a/setup.py b/setup.py index d9220be4..9eeaed4b 100755 --- a/setup.py +++ b/setup.py @@ -23,9 +23,9 @@ author_email="laughingman7743@gmail.com", license="MIT License", packages=find_packages(".", exclude=["tests"]), - package_data={"": ["LICENSE", "*.rst", "Pipfile*"],}, + package_data={"": ["LICENSE", "*.rst", "Pipfile*"]}, include_package_data=True, - data_files=[("", ["LICENSE"] + glob("*.rst") + glob("Pipfile*")),], + data_files=[("", ["LICENSE"] + glob("*.rst") + glob("Pipfile*"))], install_requires=[ "future", 'futures;python_version=="2.7"', From 8002fb769515a68f80d12fca1d600b1b3908d780 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 16:51:42 +0900 Subject: [PATCH 4/6] Interopability with black: https://github.com/timothycrosley/isort/issues/694#issuecomment-564261886 --- setup.cfg | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 4948b4c3..f3e981a1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,8 @@ universal = 1 flake8-max-line-length = 100 [isort] -line_length = 100 -order_by_type = True -multi_line_output = 4 +multi_line_output=3 
+include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 From 14e98a48fe5c49434e4324af5c7f4ff27eee7bf4 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 17:22:33 +0900 Subject: [PATCH 5/6] Fix E231 missing whitespace after ',' --- tests/test_pandas_cursor.py | 10 +++++----- tests/test_util.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/test_pandas_cursor.py b/tests/test_pandas_cursor.py index 3641ab03..7c115810 100644 --- a/tests/test_pandas_cursor.py +++ b/tests/test_pandas_cursor.py @@ -370,12 +370,12 @@ def test_integer_na_values(self, cursor): SELECT * FROM integer_na_values """ ).as_pandas() - rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] + rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] version = float(re.search(r"^([\d]+\.[\d]+)\..+", pd.__version__).group(1)) if version >= 1.0: - self.assertEqual(rows, [(1, 2), (1, pd.NA), (pd.NA, pd.NA),]) + self.assertEqual(rows, [(1, 2), (1, pd.NA), (pd.NA, pd.NA)]) else: - self.assertEqual(rows, [(1, 2), (1, np.nan), (np.nan, np.nan),]) + self.assertEqual(rows, [(1, 2), (1, np.nan), (np.nan, np.nan)]) @with_pandas_cursor() def test_boolean_na_values(self, cursor): @@ -384,8 +384,8 @@ def test_boolean_na_values(self, cursor): SELECT * FROM boolean_na_values """ ).as_pandas() - rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] - self.assertEqual(rows, [(True, False), (False, None), (None, None),]) + rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + self.assertEqual(rows, [(True, False), (False, None), (None, None)]) @with_pandas_cursor() def test_executemany(self, cursor): diff --git a/tests/test_util.py b/tests/test_util.py index 2bbb2a8c..854c54be 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -150,7 +150,7 @@ def test_as_pandas_integer_na_values(self, cursor): """ ) df = as_pandas(cursor, coerce_float=True) - rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] + rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] # TODO AssertionError: Lists differ: # [(1.0, 2.0), (1.0, nan), (nan, nan)] != [(1.0, 2.0), (1.0, nan), (nan, nan)] # self.assertEqual(rows, [ @@ -158,7 +158,7 @@ def test_as_pandas_integer_na_values(self, cursor): # (1.0, np.nan), # (np.nan, np.nan), # ]) - np.testing.assert_array_equal(rows, [(1, 2), (1, np.nan), (np.nan, np.nan),]) + np.testing.assert_array_equal(rows, [(1, 2), (1, np.nan), (np.nan, np.nan)]) @with_cursor() def test_as_pandas_boolean_na_values(self, cursor): @@ -168,8 +168,8 @@ def test_as_pandas_boolean_na_values(self, cursor): """ ) df = as_pandas(cursor) - rows = [tuple([row["a"], row["b"],]) for _, row in df.iterrows()] - self.assertEqual(rows, [(True, False), (False, None), (None, None),]) + rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + self.assertEqual(rows, [(True, False), (False, None), (None, None)]) def test_generate_ddl(self): # TODO Add binary column (After dropping support for Python 2.7) @@ -475,7 +475,7 @@ def test_to_sql_with_index(self, cursor): self.assertEqual(cursor.fetchall(), [(0, 1)]) self.assertEqual( [(d[0], d[1]) for d in cursor.description], - [("col_index", "bigint"), ("col_int", "integer"),], + [("col_index", "bigint"), ("col_int", "integer")], ) @with_cursor() @@ -504,7 +504,7 @@ def test_to_sql_with_partitions(self, cursor): sorted(cursor.fetchall()), [("col_int={0}".format(i),) for i in xrange(10)] ) cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name)) - 
self.assertEqual(cursor.fetchall(), [(10,),]) + self.assertEqual(cursor.fetchall(), [(10,)]) @with_cursor() def test_to_sql_with_multiple_partitions(self, cursor): @@ -534,7 +534,7 @@ def test_to_sql_with_multiple_partitions(self, cursor): + [("col_int={0}/col_string=b".format(i),) for i in xrange(5, 10)], ) cursor.execute("SELECT COUNT(*) FROM {0}".format(table_name)) - self.assertEqual(cursor.fetchall(), [(10,),]) + self.assertEqual(cursor.fetchall(), [(10,)]) @with_cursor() def test_to_sql_invalid_args(self, cursor): From 6c1ea65eb416b4b389daf1ead903a7acc8dba4d8 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 2 May 2020 17:59:25 +0900 Subject: [PATCH 6/6] Fix indent --- tests/test_async_cursor.py | 8 +- tests/test_async_pandas_cursor.py | 20 +- tests/test_cursor.py | 92 +++--- tests/test_formatter.py | 510 +++++++++++++++++------------- tests/test_pandas_cursor.py | 108 +++---- tests/test_sqlalchemy_athena.py | 18 +- tests/test_util.py | 126 ++++---- 7 files changed, 482 insertions(+), 400 deletions(-) diff --git a/tests/test_async_cursor.py b/tests/test_async_cursor.py index 04b859f0..f61ada67 100644 --- a/tests/test_async_cursor.py +++ b/tests/test_async_cursor.py @@ -178,10 +178,10 @@ def test_bad_query(self, cursor): def test_cancel(self, cursor): query_id, future = cursor.execute( """ - SELECT a.a * rand(), b.a * rand() - FROM many_rows a - CROSS JOIN many_rows b - """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ ) time.sleep(randint(1, 5)) cursor.cancel(query_id) diff --git a/tests/test_async_pandas_cursor.py b/tests/test_async_pandas_cursor.py index 53ecbbf5..b6bdc46d 100644 --- a/tests/test_async_pandas_cursor.py +++ b/tests/test_async_pandas_cursor.py @@ -184,10 +184,10 @@ def test_many_as_pandas(self, cursor): def test_cancel(self, cursor): query_id, future = cursor.execute( """ - SELECT a.a * rand(), b.a * rand() - FROM many_rows a - CROSS JOIN many_rows b - """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ ) time.sleep(randint(1, 5)) cursor.cancel(query_id) @@ -221,12 +221,12 @@ def test_empty_result(self, cursor): location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table) query_id, future = cursor.execute( """ - CREATE EXTERNAL TABLE IF NOT EXISTS - {schema}.{table} (number_of_rows INT) - ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' - LINES TERMINATED BY '\n' STORED AS TEXTFILE - LOCATION '{location}' - """.format( + CREATE EXTERNAL TABLE IF NOT EXISTS + {schema}.{table} (number_of_rows INT) + ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + LINES TERMINATED BY '\n' STORED AS TEXTFILE + LOCATION '{location}' + """.format( schema=SCHEMA, table=table, location=location ) ) diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 229417c9..96831135 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -186,30 +186,30 @@ def test_no_params(self, cursor): def test_contain_special_character_query(self, cursor): cursor.execute( """ - SELECT col_string FROM one_row_complex - WHERE col_string LIKE '%str%' - """ + SELECT col_string FROM one_row_complex + WHERE col_string LIKE '%str%' + """ ) self.assertEqual(cursor.fetchall(), [("a string",)]) cursor.execute( """ - SELECT col_string FROM one_row_complex - WHERE col_string LIKE '%%str%%' - """ + SELECT col_string FROM one_row_complex + WHERE col_string LIKE '%%str%%' + """ ) self.assertEqual(cursor.fetchall(), [("a string",)]) cursor.execute( """ - SELECT col_string, '%' FROM one_row_complex - WHERE col_string LIKE 
'%str%' - """ + SELECT col_string, '%' FROM one_row_complex + WHERE col_string LIKE '%str%' + """ ) self.assertEqual(cursor.fetchall(), [("a string", "%")]) cursor.execute( """ - SELECT col_string, '%%' FROM one_row_complex - WHERE col_string LIKE '%%str%%' - """ + SELECT col_string, '%%' FROM one_row_complex + WHERE col_string LIKE '%%str%%' + """ ) self.assertEqual(cursor.fetchall(), [("a string", "%%")]) @@ -219,9 +219,9 @@ def test_contain_special_character_query_with_parameter(self, cursor): TypeError, lambda: cursor.execute( """ - SELECT col_string, %(param)s FROM one_row_complex - WHERE col_string LIKE '%str%' - """, + SELECT col_string, %(param)s FROM one_row_complex + WHERE col_string LIKE '%str%' + """, {"param": "a string"}, ), ) @@ -237,9 +237,9 @@ def test_contain_special_character_query_with_parameter(self, cursor): ValueError, lambda: cursor.execute( """ - SELECT col_string, '%' FROM one_row_complex - WHERE col_string LIKE %(param)s - """, + SELECT col_string, '%' FROM one_row_complex + WHERE col_string LIKE %(param)s + """, {"param": "%str%"}, ), ) @@ -337,27 +337,27 @@ def test_query_execution_initial(self, cursor): def test_complex(self, cursor): cursor.execute( """ - SELECT - col_boolean - ,col_tinyint - ,col_smallint - ,col_int - ,col_bigint - ,col_float - ,col_double - ,col_string - ,col_timestamp - ,CAST(col_timestamp AS time) AS col_time - ,col_date - ,col_binary - ,col_array - ,CAST(col_array AS json) AS col_array_json - ,col_map - ,CAST(col_map AS json) AS col_map_json - ,col_struct - ,col_decimal - FROM one_row_complex - """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_timestamp + ,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ ) self.assertEqual( cursor.description, @@ -446,10 +446,10 @@ def cancel(c): DatabaseError, lambda: cursor.execute( """ - SELECT a.a * rand(), b.a * rand() - FROM many_rows a - CROSS JOIN many_rows b - """ + SELECT a.a * rand(), b.a * rand() + FROM many_rows a + CROSS JOIN many_rows b + """ ), ) @@ -488,9 +488,9 @@ def test_show_partition(self, cursor): for i in xrange(10): cursor.execute( """ - ALTER TABLE partition_table ADD PARTITION (b=%(b)d) - LOCATION %(location)s - """, + ALTER TABLE partition_table ADD PARTITION (b=%(b)d) + LOCATION %(location)s + """, {"b": i, "location": location}, ) cursor.execute("SHOW PARTITIONS partition_table") diff --git a/tests/test_formatter.py b/tests/test_formatter.py index f8b2c17a..590855df 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import textwrap import unittest from datetime import date, datetime from decimal import Decimal @@ -13,73 +14,90 @@ class TestDefaultParameterFormatter(unittest.TestCase): # TODO More DDL statement test case & Complex parameter format test case - FORMATTER = DefaultParameterFormatter() + def setUp(self): + self.formatter = DefaultParameterFormatter() def format(self, operation, parameters=None): - return self.FORMATTER.format(operation, parameters) + return self.formatter.format(operation, parameters) def test_add_partition(self): - expected = """ - ALTER TABLE test_table - ADD PARTITION (dt=DATE '2017-01-01', hour=1) - """.strip() + expected = textwrap.dedent( + """ + ALTER TABLE 
test_table + ADD PARTITION (dt=DATE '2017-01-01', hour=1) + """ + ).strip() actual = self.format( - """ - ALTER TABLE test_table - ADD PARTITION (dt=%(dt)s, hour=%(hour)d) - """, + textwrap.dedent( + """ + ALTER TABLE test_table + ADD PARTITION (dt=%(dt)s, hour=%(hour)d) + """ + ).strip(), {"dt": date(2017, 1, 1), "hour": 1}, ) self.assertEqual(actual, expected) def test_drop_partition(self): - expected = """ - ALTER TABLE test_table - DROP PARTITION (dt=DATE '2017-01-01', hour=1) - """.strip() + expected = textwrap.dedent( + """ + ALTER TABLE test_table + DROP PARTITION (dt=DATE '2017-01-01', hour=1) + """ + ).strip() actual = self.format( - """ - ALTER TABLE test_table - DROP PARTITION (dt=%(dt)s, hour=%(hour)d) - """, + textwrap.dedent( + """ + ALTER TABLE test_table + DROP PARTITION (dt=%(dt)s, hour=%(hour)d) + """ + ).strip(), {"dt": date(2017, 1, 1), "hour": 1}, ) self.assertEqual(actual, expected) def test_format_none(self): - expected = """ - SELECT * - FROM test_table - WHERE col is null - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col is null + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col is %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col is %(param)s + """ + ).strip(), {"param": None}, ) self.assertEqual(actual, expected) def test_format_datetime(self): - expected = """ - SELECT * - FROM test_table - WHERE col_timestamp >= TIMESTAMP '2017-01-01 12:00:00.000' - AND col_timestamp <= TIMESTAMP '2017-01-02 06:00:00.000' - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_timestamp >= TIMESTAMP '2017-01-01 12:00:00.000' + AND col_timestamp <= TIMESTAMP '2017-01-02 06:00:00.000' + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_timestamp >= %(start)s - AND col_timestamp <= %(end)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_timestamp >= %(start)s + AND col_timestamp <= %(end)s + """ + ).strip(), { "start": datetime(2017, 1, 1, 12, 0, 0), "end": datetime(2017, 1, 2, 6, 0, 0), @@ -88,276 +106,340 @@ def test_format_datetime(self): self.assertEqual(actual, expected) def test_format_date(self): - expected = """ - SELECT * - FROM test_table - WHERE col_date between DATE '2017-01-01' and DATE '2017-01-02' - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_date between DATE '2017-01-01' and DATE '2017-01-02' + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_date between %(start)s and %(end)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_date between %(start)s and %(end)s + """ + ).strip(), {"start": date(2017, 1, 1), "end": date(2017, 1, 2)}, ) self.assertEqual(actual, expected) def test_format_int(self): - expected = """ - SELECT * - FROM test_table - WHERE col_int = 1 - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_int = 1 + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_int = %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_int = %(param)s + """ + ).strip(), {"param": 1}, ) self.assertEqual(actual, expected) def test_format_float(self): - expected = """ - SELECT * - FROM test_table - WHERE col_float >= 0.1 - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_float >= 0.1 + """ + ).strip() actual = self.format( - """ - 
SELECT * - FROM test_table - WHERE col_float >= %(param).1f - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_float >= %(param).1f + """ + ).strip(), {"param": 0.1}, ) self.assertEqual(actual, expected) def test_format_decimal(self): - expected = """ - SELECT * - FROM test_table - WHERE col_decimal <= DECIMAL '0.0000000001' - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_decimal <= DECIMAL '0.0000000001' + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_decimal <= %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_decimal <= %(param)s + """ + ).strip(), {"param": Decimal("0.0000000001")}, ) self.assertEqual(actual, expected) def test_format_bool(self): - expected = """ - SELECT * - FROM test_table - WHERE col_boolean = True - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_boolean = True + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_boolean = %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_boolean = %(param)s + """ + ).strip(), {"param": True}, ) self.assertEqual(actual, expected) def test_format_str(self): - expected = """ - SELECT * - FROM test_table - WHERE col_string = 'amazon athena' - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string = 'amazon athena' + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_string = %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string = %(param)s + """ + ).strip(), {"param": "amazon athena"}, ) self.assertEqual(actual, expected) def test_format_unicode(self): - expected = """ - SELECT * - FROM test_table - WHERE col_string = '密林 女神' - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string = '密林 女神' + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_string = %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string = %(param)s + """ + ).strip(), {"param": "密林 女神"}, ) self.assertEqual(actual, expected) def test_format_none_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col IN (null, null) - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col IN (null, null) + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col IN %(param)s + """ + ).strip(), {"param": [None, None]}, ) self.assertEqual(actual, expected) def test_format_datetime_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_timestamp IN - (TIMESTAMP '2017-01-01 12:00:00.000', TIMESTAMP '2017-01-02 06:00:00.000') - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_timestamp IN + (TIMESTAMP '2017-01-01 12:00:00.000', TIMESTAMP '2017-01-02 06:00:00.000') + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_timestamp IN - %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_timestamp IN + %(param)s + """ + ).strip(), {"param": [datetime(2017, 1, 1, 12, 0, 0), datetime(2017, 1, 2, 6, 0, 0)]}, ) self.assertEqual(actual, expected) def test_format_date_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_date IN (DATE '2017-01-01', 
DATE '2017-01-02') - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_date IN (DATE '2017-01-01', DATE '2017-01-02') + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_date IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_date IN %(param)s + """ + ).strip(), {"param": [date(2017, 1, 1), date(2017, 1, 2)]}, ) self.assertEqual(actual, expected) def test_format_int_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_int IN (1, 2) - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_int IN (1, 2) + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_int IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_int IN %(param)s + """ + ).strip(), {"param": [1, 2]}, ) self.assertEqual(actual, expected) def test_format_float_list(self): # default precision is 6 - expected = """ - SELECT * - FROM test_table - WHERE col_float IN (0.100000, 0.200000) - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_float IN (0.100000, 0.200000) + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_float IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_float IN %(param)s + """ + ).strip(), {"param": [0.1, 0.2]}, ) self.assertEqual(actual, expected) def test_format_decimal_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999') - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999') + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_decimal IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_decimal IN %(param)s + """ + ).strip(), {"param": [Decimal("0.0000000001"), Decimal("99.9999999999")]}, ) self.assertEqual(actual, expected) def test_format_bool_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_boolean IN (True, False) - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_boolean IN (True, False) + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_boolean IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_boolean IN %(param)s + """ + ).strip(), {"param": [True, False]}, ) self.assertEqual(actual, expected) def test_format_str_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_string IN ('amazon', 'athena') - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string IN ('amazon', 'athena') + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_string IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string IN %(param)s + """ + ).strip(), {"param": ["amazon", "athena"]}, ) self.assertEqual(actual, expected) def test_format_unicode_list(self): - expected = """ - SELECT * - FROM test_table - WHERE col_string IN ('密林', '女神') - """.strip() + expected = textwrap.dedent( + """ + SELECT * + FROM test_table + WHERE col_string IN ('密林', '女神') + """ + ).strip() actual = self.format( - """ - SELECT * - FROM test_table - WHERE col_string IN %(param)s - """, + textwrap.dedent( + """ + SELECT * + FROM 
test_table + WHERE col_string IN %(param)s + """ + ).strip(), {"param": ["密林", "女神"]}, ) self.assertEqual(actual, expected) @@ -367,10 +449,10 @@ def test_format_bad_parameter(self): ProgrammingError, lambda: self.format( """ - SELECT * - FROM test_table - where col_int = $(param)d - """.strip(), + SELECT * + FROM test_table + where col_int = $(param)d + """.strip(), 1, ), ) @@ -379,10 +461,10 @@ def test_format_bad_parameter(self): ProgrammingError, lambda: self.format( """ - SELECT * - FROM test_table - where col_string = $(param)s - """.strip(), + SELECT * + FROM test_table + where col_string = $(param)s + """.strip(), "a string", ), ) @@ -391,10 +473,10 @@ def test_format_bad_parameter(self): ProgrammingError, lambda: self.format( """ - SELECT * - FROM test_table - where col_string in $(param)s - """.strip(), + SELECT * + FROM test_table + where col_string in $(param)s + """.strip(), ["a string"], ), ) diff --git a/tests/test_pandas_cursor.py b/tests/test_pandas_cursor.py index 7c115810..d4311521 100644 --- a/tests/test_pandas_cursor.py +++ b/tests/test_pandas_cursor.py @@ -73,27 +73,27 @@ def test_invalid_arraysize(self, cursor): def test_complex(self, cursor): cursor.execute( """ - SELECT - col_boolean - ,col_tinyint - ,col_smallint - ,col_int - ,col_bigint - ,col_float - ,col_double - ,col_string - ,col_timestamp - ,CAST(col_timestamp AS time) AS col_time - ,col_date - ,col_binary - ,col_array - ,CAST(col_array AS json) AS col_array_json - ,col_map - ,CAST(col_map AS json) AS col_map_json - ,col_struct - ,col_decimal - FROM one_row_complex - """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_timestamp + ,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ ) self.assertEqual( cursor.description, @@ -181,27 +181,27 @@ def test_many_as_pandas(self, cursor): def test_complex_as_pandas(self, cursor): df = cursor.execute( """ - SELECT - col_boolean - ,col_tinyint - ,col_smallint - ,col_int - ,col_bigint - ,col_float - ,col_double - ,col_string - ,col_timestamp - ,CAST(col_timestamp AS time) AS col_time - ,col_date - ,col_binary - ,col_array - ,CAST(col_array AS json) AS col_array_json - ,col_map - ,CAST(col_map AS json) AS col_map_json - ,col_struct - ,col_decimal - FROM one_row_complex - """ + SELECT + col_boolean + ,col_tinyint + ,col_smallint + ,col_int + ,col_bigint + ,col_float + ,col_double + ,col_string + ,col_timestamp + ,CAST(col_timestamp AS time) AS col_time + ,col_date + ,col_binary + ,col_array + ,CAST(col_array AS json) AS col_array_json + ,col_map + ,CAST(col_map AS json) AS col_map_json + ,col_struct + ,col_decimal + FROM one_row_complex + """ ).as_pandas() self.assertEqual(df.shape[0], 1) self.assertEqual(df.shape[1], 18) @@ -351,12 +351,12 @@ def test_empty_result(self, cursor): location = "{0}{1}/{2}/".format(ENV.s3_staging_dir, S3_PREFIX, table) df = cursor.execute( """ - CREATE EXTERNAL TABLE IF NOT EXISTS - {schema}.{table} (number_of_rows INT) - ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' - LINES TERMINATED BY '\n' STORED AS TEXTFILE - LOCATION '{location}' - """.format( + CREATE EXTERNAL TABLE IF NOT EXISTS + {schema}.{table} (number_of_rows INT) + ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' + LINES TERMINATED BY '\n' STORED AS TEXTFILE + LOCATION '{location}' + """.format( schema=SCHEMA, table=table, 
location=location ) ).as_pandas() @@ -367,8 +367,8 @@ def test_empty_result(self, cursor): def test_integer_na_values(self, cursor): df = cursor.execute( """ - SELECT * FROM integer_na_values - """ + SELECT * FROM integer_na_values + """ ).as_pandas() rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] version = float(re.search(r"^([\d]+\.[\d]+)\..+", pd.__version__).group(1)) @@ -381,8 +381,8 @@ def test_integer_na_values(self, cursor): def test_boolean_na_values(self, cursor): df = cursor.execute( """ - SELECT * FROM boolean_na_values - """ + SELECT * FROM boolean_na_values + """ ).as_pandas() rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] self.assertEqual(rows, [(True, False), (False, None), (None, None)]) @@ -409,7 +409,7 @@ def test_executemany_fetch(self, cursor): def test_not_skip_blank_lines(self, cursor): cursor.execute( """ - select * from (values (1), (NULL)) - """ + select * from (values (1), (NULL)) + """ ) self.assertEqual(len(cursor.fetchall()), 2) diff --git a/tests/test_sqlalchemy_athena.py b/tests/test_sqlalchemy_athena.py index d436d5ed..b85bf0ef 100644 --- a/tests/test_sqlalchemy_athena.py +++ b/tests/test_sqlalchemy_athena.py @@ -358,8 +358,8 @@ def test_get_column_type(self, engine, conn): def test_contain_percents_character_query(self, engine, conn): query = sqlalchemy.sql.text( """ - SELECT date_parse('20191030', '%Y%m%d') - """ + SELECT date_parse('20191030', '%Y%m%d') + """ ) result = engine.execute(query) self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30),)]) @@ -368,8 +368,8 @@ def test_contain_percents_character_query(self, engine, conn): def test_query_with_parameter(self, engine, conn): query = sqlalchemy.sql.text( """ - SELECT :word - """ + SELECT :word + """ ) result = engine.execute(query, word="cat") self.assertEqual(result.fetchall(), [("cat",)]) @@ -378,17 +378,17 @@ def test_query_with_parameter(self, engine, conn): def test_contain_percents_character_query_with_parameter(self, engine, conn): query = sqlalchemy.sql.text( """ - SELECT date_parse('20191030', '%Y%m%d'), :word - """ + SELECT date_parse('20191030', '%Y%m%d'), :word + """ ) result = engine.execute(query, word="cat") self.assertEqual(result.fetchall(), [(datetime(2019, 10, 30), "cat")]) query = sqlalchemy.sql.text( """ - SELECT col_string FROM one_row_complex - WHERE col_string LIKE 'a%' OR col_string LIKE :param - """ + SELECT col_string FROM one_row_complex + WHERE col_string LIKE 'a%' OR col_string LIKE :param + """ ) result = engine.execute(query, param="b%") self.assertEqual(result.fetchall(), [("a string",)]) diff --git a/tests/test_util.py b/tests/test_util.py index 854c54be..392c404d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -146,8 +146,8 @@ def test_as_pandas(self, cursor): def test_as_pandas_integer_na_values(self, cursor): cursor.execute( """ - SELECT * FROM integer_na_values - """ + SELECT * FROM integer_na_values + """ ) df = as_pandas(cursor, coerce_float=True) rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] @@ -206,20 +206,20 @@ def test_generate_ddl(self): actual.strip(), textwrap.dedent( """ - CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( - `col_int` INT, - `col_bigint` BIGINT, - `col_float` FLOAT, - `col_double` DOUBLE, - `col_string` STRING, - `col_boolean` BOOLEAN, - `col_timestamp` TIMESTAMP, - `col_date` DATE, - `col_timedelta` BIGINT - ) - STORED AS PARQUET - LOCATION 's3://bucket/path/to/' - """ + CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( + `col_int` 
INT, + `col_bigint` BIGINT, + `col_float` FLOAT, + `col_double` DOUBLE, + `col_string` STRING, + `col_boolean` BOOLEAN, + `col_timestamp` TIMESTAMP, + `col_date` DATE, + `col_timedelta` BIGINT + ) + STORED AS PARQUET + LOCATION 's3://bucket/path/to/' + """ ).strip(), ) @@ -235,21 +235,21 @@ def test_generate_ddl(self): actual.strip(), textwrap.dedent( """ - CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( - `col_int` INT, - `col_bigint` BIGINT, - `col_float` FLOAT, - `col_double` DOUBLE, - `col_string` STRING, - `col_boolean` BOOLEAN, - `col_timestamp` TIMESTAMP, - `col_date` DATE, - `col_timedelta` BIGINT - ) - STORED AS PARQUET - LOCATION 's3://bucket/path/to/' - TBLPROPERTIES ('parquet.compress'='SNAPPY') - """ + CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( + `col_int` INT, + `col_bigint` BIGINT, + `col_float` FLOAT, + `col_double` DOUBLE, + `col_string` STRING, + `col_boolean` BOOLEAN, + `col_timestamp` TIMESTAMP, + `col_date` DATE, + `col_timedelta` BIGINT + ) + STORED AS PARQUET + LOCATION 's3://bucket/path/to/' + TBLPROPERTIES ('parquet.compress'='SNAPPY') + """ ).strip(), ) @@ -265,22 +265,22 @@ def test_generate_ddl(self): actual.strip(), textwrap.dedent( """ - CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( - `col_bigint` BIGINT, - `col_float` FLOAT, - `col_double` DOUBLE, - `col_string` STRING, - `col_boolean` BOOLEAN, - `col_timestamp` TIMESTAMP, - `col_date` DATE, - `col_timedelta` BIGINT - ) - PARTITIONED BY ( - `col_int` INT - ) - STORED AS PARQUET - LOCATION 's3://bucket/path/to/' - """ + CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( + `col_bigint` BIGINT, + `col_float` FLOAT, + `col_double` DOUBLE, + `col_string` STRING, + `col_boolean` BOOLEAN, + `col_timestamp` TIMESTAMP, + `col_date` DATE, + `col_timedelta` BIGINT + ) + PARTITIONED BY ( + `col_int` INT + ) + STORED AS PARQUET + LOCATION 's3://bucket/path/to/' + """ ).strip(), ) @@ -296,22 +296,22 @@ def test_generate_ddl(self): actual.strip(), textwrap.dedent( """ - CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( - `col_bigint` BIGINT, - `col_float` FLOAT, - `col_double` DOUBLE, - `col_boolean` BOOLEAN, - `col_timestamp` TIMESTAMP, - `col_date` DATE, - `col_timedelta` BIGINT - ) - PARTITIONED BY ( - `col_int` INT, - `col_string` STRING - ) - STORED AS PARQUET - LOCATION 's3://bucket/path/to/' - """ + CREATE EXTERNAL TABLE IF NOT EXISTS `test_schema`.`test_table` ( + `col_bigint` BIGINT, + `col_float` FLOAT, + `col_double` DOUBLE, + `col_boolean` BOOLEAN, + `col_timestamp` TIMESTAMP, + `col_date` DATE, + `col_timedelta` BIGINT + ) + PARTITIONED BY ( + `col_int` INT, + `col_string` STRING + ) + STORED AS PARQUET + LOCATION 's3://bucket/path/to/' + """ ).strip(), )
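
Note on the indentation changes in PATCH 6/6: the expected strings in tests/test_formatter.py and tests/test_util.py are rewritten as textwrap.dedent(...).strip(), so the assertions no longer depend on how deeply the triple-quoted literals are indented in the test source. The standalone sketch below (illustration only, not part of the patch series; the query text simply mirrors the test_format_int fixture) shows that invariance under the assumption of CPython's standard textwrap module:

# Illustration only (not part of the patch series): wrapping the expected SQL
# in textwrap.dedent(...).strip() makes the comparison independent of how far
# the triple-quoted literal is indented in the surrounding code.
import textwrap


def normalize(sql):
    # Strip the common leading whitespace contributed by the surrounding
    # code indentation, then drop the leading/trailing blank lines.
    return textwrap.dedent(sql).strip()


SHALLOW = """
SELECT *
FROM test_table
WHERE col_int = 1
"""

DEEP = """
            SELECT *
            FROM test_table
            WHERE col_int = 1
            """

# Re-indenting the literal (for example after black or a manual indent fix)
# does not change the normalized text, so string assertions still pass.
assert normalize(SHALLOW) == normalize(DEEP)
assert normalize(DEEP) == "SELECT *\nFROM test_table\nWHERE col_int = 1"

With this normalization in place, black and later indentation fixes can re-indent the test bodies freely without breaking the expected/actual string comparisons.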