From 2b7bbdaf8219bd4fb78bc265227ee5bf08aabd51 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Wed, 10 Nov 2021 17:00:33 +0200 Subject: [PATCH] first release --- .editorconfig | 21 + .github/ISSUE_TEMPLATE.md | 15 + .gitignore | 106 +++ .readthedocs.yaml | 30 + .travis.yml | 17 + AUTHORS.rst | 13 + CONTRIBUTING.rst | 129 +++ HISTORY.rst | 8 + LICENSE | 22 + MANIFEST.in | 11 + Makefile | 87 ++ README.rst | 481 +++++++++++ docs/Makefile | 192 +++++ docs/_templates/page.html | 8 + docs/api.rst | 105 +++ docs/authors.rst | 1 + docs/conf.py | 62 ++ docs/contributing.rst | 1 + docs/index.rst | 21 + docs/installation.rst | 51 ++ docs/performance.rst | 20 + docs/readme.rst | 1 + docs/requirements.txt | 10 + drf_turbo/__init__.py | 64 ++ drf_turbo/cython_metaclass.h | 107 +++ drf_turbo/cython_metaclass.pxd | 2 + drf_turbo/cython_metaclass.pyx | 123 +++ drf_turbo/exceptions.pyx | 48 ++ drf_turbo/fields.pxd | 178 ++++ drf_turbo/fields.pyx | 1450 ++++++++++++++++++++++++++++++++ drf_turbo/meta.py | 191 +++++ drf_turbo/openapi.py | 921 ++++++++++++++++++++ drf_turbo/parsers.pyx | 91 ++ drf_turbo/renderers.pyx | 156 ++++ drf_turbo/response.pyx | 73 ++ drf_turbo/serializer.pxd | 31 + drf_turbo/serializer.pyx | 509 +++++++++++ drf_turbo/templates/docs.html | 38 + drf_turbo/utils.pyx | 80 ++ input.txt | 4 + requirements.txt | 7 + runtests.py | 51 ++ setup.cfg | 18 + setup.py | 48 ++ tests/__init__.py | 1 + tests/conftest.py | 60 ++ tests/test_fields.py | 1054 +++++++++++++++++++++++ tests/test_parsers.py | 64 ++ tests/test_renderers.py | 180 ++++ tests/test_response.py | 251 ++++++ tests/test_serializer.py | 442 ++++++++++ tox.ini | 19 + 52 files changed, 7673 insertions(+) create mode 100755 .editorconfig create mode 100755 .github/ISSUE_TEMPLATE.md create mode 100755 .gitignore create mode 100755 .readthedocs.yaml create mode 100755 .travis.yml create mode 100755 AUTHORS.rst create mode 100755 CONTRIBUTING.rst create mode 100755 HISTORY.rst create mode 100755 LICENSE create 
mode 100755 MANIFEST.in create mode 100755 Makefile create mode 100755 README.rst create mode 100755 docs/Makefile create mode 100755 docs/_templates/page.html create mode 100755 docs/api.rst create mode 100755 docs/authors.rst create mode 100755 docs/conf.py create mode 100755 docs/contributing.rst create mode 100755 docs/index.rst create mode 100755 docs/installation.rst create mode 100755 docs/performance.rst create mode 100755 docs/readme.rst create mode 100755 docs/requirements.txt create mode 100755 drf_turbo/__init__.py create mode 100755 drf_turbo/cython_metaclass.h create mode 100755 drf_turbo/cython_metaclass.pxd create mode 100755 drf_turbo/cython_metaclass.pyx create mode 100755 drf_turbo/exceptions.pyx create mode 100755 drf_turbo/fields.pxd create mode 100755 drf_turbo/fields.pyx create mode 100755 drf_turbo/meta.py create mode 100755 drf_turbo/openapi.py create mode 100755 drf_turbo/parsers.pyx create mode 100755 drf_turbo/renderers.pyx create mode 100755 drf_turbo/response.pyx create mode 100755 drf_turbo/serializer.pxd create mode 100755 drf_turbo/serializer.pyx create mode 100755 drf_turbo/templates/docs.html create mode 100755 drf_turbo/utils.pyx create mode 100644 input.txt create mode 100755 requirements.txt create mode 100755 runtests.py create mode 100755 setup.cfg create mode 100755 setup.py create mode 100755 tests/__init__.py create mode 100755 tests/conftest.py create mode 100755 tests/test_fields.py create mode 100644 tests/test_parsers.py create mode 100644 tests/test_renderers.py create mode 100644 tests/test_response.py create mode 100644 tests/test_serializer.py create mode 100755 tox.ini diff --git a/.editorconfig b/.editorconfig new file mode 100755 index 0000000..d4a2c44 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,21 @@ +# http://editorconfig.org + +root = true + +[*] +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true +charset = utf-8 +end_of_line = lf + +[*.bat] +indent_style = 
tab +end_of_line = crlf + +[LICENSE] +insert_final_newline = false + +[Makefile] +indent_style = tab diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100755 index 0000000..d4b1407 --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,15 @@ +* drf-turbo version: +* Python version: +* Operating System: + +### Description + +Describe what you were trying to get done. +Tell us what happened, what went wrong, and what you expected to happen. + +### What I Did + +``` +Paste the command(s) you ran and the output. +If there was a crash, please include the traceback here. +``` diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..4c915d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,106 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# IDE settings +.vscode/ +.idea/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100755 index 0000000..b44a946 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,30 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.9" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +# - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + # install build requirements + - requirements: docs/requirements.txt + # build sktime using `python ./setup.py install --force` + - method: setuptools + path: . 
+ diff --git a/.travis.yml b/.travis.yml new file mode 100755 index 0000000..38e59f5 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,17 @@ +language: python +python: +- 3.8 +- 3.7 +- 3.6 +install: pip install -U tox-travis +script: tox +deploy: + provider: pypi + distributions: sdist bdist_wheel + user: __token__ + password: + secure: sj5UwgJ3bxgQD1bQt4+M07iVmu9pT5bjxocMTs3qcO920xVI3CDGuBmV37uQOCPpZHwDVCmLMpucamZgX27ArG7sm6OJHZGXaqYEEceFLPSB3HW4XK+H6W03IENRgeYwbqOjE+u3YzICdxJ1Ue+e6IHqXp2OpCtxn1uWuQRwCnp9MjzfwqkSgtvldVAGafo8qS9rrSlFqaKkMBS2UZsxP5+oMlHoVeLQqs8Rb6O4NOS+qrTgzffDRaEX/e95reqMHrUqUEpjc3ZAn1ECNHFB7gCfqxqCxH1nHng3fQI7nHHz86Aa6WCfTjJZlolzFg0I2d5GqU8dZyTJ1KBvxZkiuGsW9Upg/sSpgEdUFeaphSNl3Q+VEnkFzdtHFUlQGfHS6il/y4kCeozhkB5FRUqBqfKlkb6LSBkkZxp6FHK+n+gXUyOvCRf92sFboo+21BvyGsbCaxXjDgQlQ8pRKkW+XRcygxNpWdBlT6ng2fKBiRQqf+fr2sOvA3dLn4Q/FDNpd6XiaLUajSKudqe4XbTZcNueJNO5zkrOr3qDekTvFtleMnCesurpl6/OZ+RSSfDKETZMg5GQiX5P2fLsg4/lYciKJADlWQ5SwhAVAosgftY2FC/ronI08vXjL/LWWxOOeOr/GVbxdxQBIYaKAch8+xjJYeyyYZzFnm86RBN+lLc= + on: + tags: true + repo: Mng-dev-ai/drf_turbo + python: 3.8 diff --git a/AUTHORS.rst b/AUTHORS.rst new file mode 100755 index 0000000..f548c2c --- /dev/null +++ b/AUTHORS.rst @@ -0,0 +1,13 @@ +======= +Credits +======= + +Development Lead +---------------- + +* Michael Gendy `@Mng `_ + +Contributors +------------ + +None yet. Why not be the first? diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst new file mode 100755 index 0000000..b40a7c2 --- /dev/null +++ b/CONTRIBUTING.rst @@ -0,0 +1,129 @@ +.. highlight:: shell + +============ +Contributing +============ + +Contributions are welcome, and they are greatly appreciated! Every little bit +helps, and credit will always be given. + +You can contribute in many ways: + +Types of Contributions +---------------------- + +Report Bugs +~~~~~~~~~~~ + +Report bugs at https://github.com/Mng-dev-ai/drf-turbo/issues. + +If you are reporting a bug, please include: + +* Your operating system name and version. 
+* Any details about your local setup that might be helpful in troubleshooting. +* Detailed steps to reproduce the bug. + +Fix Bugs +~~~~~~~~ + +Look through the GitHub issues for bugs. Anything tagged with "bug" and "help +wanted" is open to whoever wants to implement it. + +Implement Features +~~~~~~~~~~~~~~~~~~ + +Look through the GitHub issues for features. Anything tagged with "enhancement" +and "help wanted" is open to whoever wants to implement it. + +Write Documentation +~~~~~~~~~~~~~~~~~~~ + +drf-turbo could always use more documentation, whether as part of the +official drf-turbo docs, in docstrings, or even on the web in blog posts, +articles, and such. + +Submit Feedback +~~~~~~~~~~~~~~~ + +The best way to send feedback is to file an issue at https://github.com/Mng-dev-ai/drf-turbo/issues. + +If you are proposing a feature: + +* Explain in detail how it would work. +* Keep the scope as narrow as possible, to make it easier to implement. +* Remember that this is a volunteer-driven project, and that contributions + are welcome :) + +Get Started! +------------ + +Ready to contribute? Here's how to set up `drf-turbo` for local development. + +1. Fork the `drf-turbo` repo on GitHub. +2. Clone your fork locally:: + + $ git clone git@github.com:your_name_here/drf-turbo.git + +3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: + + $ mkvirtualenv drf_turbo + $ cd drf_turbo/ + $ python setup.py develop + +4. Create a branch for local development:: + + $ git checkout -b name-of-your-bugfix-or-feature + + Now you can make your changes locally. + +5. When you're done making changes, check that your changes pass flake8 and the + tests, including testing other Python versions with tox:: + + $ flake8 drf-turbo tests + $ python setup.py test or pytest + $ tox + + To get flake8 and tox, just pip install them into your virtualenv. + +6. 
Commit your changes and push your branch to GitHub:: + + $ git add . + $ git commit -m "Your detailed description of your changes." + $ git push origin name-of-your-bugfix-or-feature + +7. Submit a pull request through the GitHub website. + +Pull Request Guidelines +----------------------- + +Before you submit a pull request, check that it meets these guidelines: + +1. The pull request should include tests. +2. If the pull request adds functionality, the docs should be updated. Put + your new functionality into a function with a docstring, and add the + feature to the list in README.rst. +3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check + https://travis-ci.com/Mng-dev-ai/drf-turbo/pull_requests + and make sure that the tests pass for all supported Python versions. + +Testing +------- + +To run the tests: + + + $ ./runtests.py -q + + +Deploying +--------- + +A reminder for the maintainers on how to deploy. +Make sure all your changes are committed (including an entry in HISTORY.rst). +Then run:: + +$ bump2version patch # possible: major / minor / patch +$ git push +$ git push --tags + +Travis will then deploy to PyPI if tests pass. diff --git a/HISTORY.rst b/HISTORY.rst new file mode 100755 index 0000000..672f7c3 --- /dev/null +++ b/HISTORY.rst @@ -0,0 +1,8 @@ +======= +History +======= + +0.1.0 (2021-11-10) +------------------ + +* First release on PyPI. 
diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..914a0d1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2021, Michael Nagy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100755 index 0000000..965b2dd --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,11 @@ +include AUTHORS.rst +include CONTRIBUTING.rst +include HISTORY.rst +include LICENSE +include README.rst + +recursive-include tests * +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] + +recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/Makefile b/Makefile new file mode 100755 index 0000000..fc93dfb --- /dev/null +++ b/Makefile @@ -0,0 +1,87 @@ +.PHONY: clean clean-test clean-pyc clean-build docs help +.DEFAULT_GOAL := help + +define BROWSER_PYSCRIPT +import os, webbrowser, sys + +from urllib.request import pathname2url + +webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) +endef +export BROWSER_PYSCRIPT + +define PRINT_HELP_PYSCRIPT +import re, sys + +for line in sys.stdin: + match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) + if match: + target, help = match.groups() + print("%-20s %s" % (target, help)) +endef +export PRINT_HELP_PYSCRIPT + +BROWSER := python -c "$$BROWSER_PYSCRIPT" + +help: + @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) + +clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts + +clean-build: ## remove build artifacts + rm -fr build/ + rm -fr dist/ + rm -fr .eggs/ + find . -name '*.egg-info' -exec rm -fr {} + + find . -name '*.egg' -exec rm -f {} + + +clean-pyc: ## remove Python file artifacts + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . 
-name '__pycache__' -exec rm -fr {} + + +clean-test: ## remove test and coverage artifacts + rm -fr .tox/ + rm -f .coverage + rm -fr htmlcov/ + rm -fr .pytest_cache + +lint/flake8: ## check style with flake8 + flake8 drf_turbo tests + +lint: lint/flake8 ## check style + +test: ## run tests quickly with the default Python + python setup.py test + +test-all: ## run tests on every Python version with tox + tox + +coverage: ## check code coverage quickly with the default Python + coverage run --source drf_turbo setup.py test + coverage report -m + coverage html + $(BROWSER) htmlcov/index.html + +docs: ## generate Sphinx HTML documentation, including API docs + rm -f docs/drf_turbo.rst + rm -f docs/modules.rst + sphinx-apidoc -o docs/ drf_turbo + $(MAKE) -C docs clean + $(MAKE) -C docs html + $(BROWSER) docs/_build/html/index.html + +servedocs: docs ## compile the docs watching for changes + watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . + +release: dist ## package and upload a release + twine upload dist/* + +dist: clean ## builds source and wheel package + python setup.py sdist + python setup.py bdist_wheel + ls -l dist + +install: clean ## install the package to the active Python's site-packages + python setup.py install diff --git a/README.rst b/README.rst new file mode 100755 index 0000000..93396d6 --- /dev/null +++ b/README.rst @@ -0,0 +1,481 @@ +========= +drf-turbo +========= + + +.. image:: https://img.shields.io/pypi/v/drf-turbo.svg + :target: https://pypi.python.org/pypi/drf-turbo + +.. image:: https://img.shields.io/travis/Mng-dev-ai/drf-turbo.svg + :target: https://travis-ci.com/Mng-dev-ai/drf-turbo + +.. image:: https://readthedocs.org/projects/drf-turbo/badge/?version=latest + :target: https://drf-turbo.readthedocs.io/en/latest/?version=latest + :alt: Documentation Status + + +.. 
image:: https://pyup.io/repos/github/Mng-dev-ai/drf-turbo/shield.svg + :target: https://pyup.io/repos/github/Mng-dev-ai/drf-turbo/ + :alt: Updates + + + +An alternative serializer implementation for REST framework written in cython built for speed. + + +* Free software: MIT license +* Documentation: https://drf-turbo.readthedocs.io. + + +**NOTE**: Cython is required to build this package. + + +Requirements +------------ + +* Django + +* Django REST Framework + +* Cython + +* forbiddenfruit + +* psycopg2-binary + +* pyyaml(OpenAPI) + +* uritemplate(OpenAPI) + +* djangorestframework-simplejwt(OpenAPI) + + +Installation +------------ + +.. code-block:: console + + $ pip install drf-turbo + + +Examples +======== + +Declaring Serializers +--------------------- +.. code-block:: python + + from datetime import datetime + from django.utils.timezone import now + import drf_turbo as dt + + class User: + def __init__(self, username, email,created=None): + self.username = username + self.email = email + self.created = created or datetime.now() + + user = User(username='test' , email='test@example.com') + + + + class UserSerializer(dt.Serializer): + username = dt.StrField(max_length=50) + email = dt.EmailField() + created = dt.DateTimeField() + + +Serializing objects +------------------- + +.. code-block:: python + + + serializer = UserSerializer(user) + serializer.data + + # {'username': 'test', 'email': 'test@example.com', 'created': '2021-11-04T22:49:01.981127Z'} + + +Deserializing objects +--------------------- + +.. code-block:: python + + data = {'username':'new_test','email':'test2@example.com','created':now()} + serializer = UserSerializer(data=data) + serializer.is_valid() + # True + serializer.validated_data + # {'username': 'new_test', 'email': 'test2@example.com', 'created': '2021-11-04T23:29:13.191304Z'} + +Validation +---------- + +.. 
code-block:: python + + serializer = UserSerializer(data={'email': 'test'}) + serializer.is_valid() + # False + serializer.errors + # {'username': ['This field is required.'], 'email': ['Enter a valid email address.']} + + +Field-level validation +---------------------- + +.. code-block:: python + + import drf_turbo as dt + + class UserSerializer(dt.Serializer): + username = dt.StrField(max_length=50) + + def validate_username(self, value): + if 'test' not in value.lower(): + raise dt.ValidationError("test must be in username") + return value + +Object-level validation +----------------------- + +.. code-block:: python + + import drf_turbo as dt + + class CampaignSerializer(dt.Serializer): + start_date = dt.DateTimeField() + end_date = dt.DateTimeField() + + def validate(self, data): + if data['start_date'] > data['end_date']: + raise dt.ValidationError("start_date must occur before end_date") + return data + +Nested Serializers +------------------ +.. code-block:: python + + from datetime import datetime + from django.utils.timezone import now + import drf_turbo as dt + + class User: + def __init__(self, username, email,created=None): + self.username = username + self.email = email + self.created = created or datetime.now() + + user = User(username='test' , email='test@example.com') + + class UserSerializer(dt.Serializer): + username = dt.StrField(max_length=50) + email = dt.EmailField() + created = dt.DateTimeField() + + class Profile : + def __init__(self, age=25): + self.age = age; self.user = user + + profile = Profile() + + + class ProfileSerializer(dt.Serializer): + age = dt.IntField() + user = UserSerializer() + + + serializer = ProfileSerializer(profile) + serializer.data + + # {'age' : 25 , 'user' : {'username': 'test', 'email': 'test@example.com', 'created': '2021-11-04T22:49:01.981127Z'}} + + +Filtering Output +---------------- + +drf-turbo provides option to include or exclude fields from serializer using ``only`` or ``exclude`` keywords. + +.. 
code-block:: python + + serializer = UserSerializer(only=('id','username')) + + or + + serializer = ProfileSerializer(exclude=('id','user__email')) + + or + + http://127.0.0.1:8000/?only=id,username + + +Required Fields +--------------- + +Make a field required by passing required=True. An error will be raised if the value is missing from data during deserialization. + +For example: + +.. code-block:: python + + class UserSerializer(dt.Serializer): + + username = dt.StrField(required=True,error_messages={"required":"no username"}) + + + +Specifying Defaults +------------------- + +It will be used for the field if no input value is supplied. + + +For example: + +.. code-block:: python + + from datetime import datetime + + class UserSerializer(dt.Serializer): + + birthdate = dt.DateTimeField(default=datetime(2021, 11, 5)) + + + + +ModelSerializer +--------------- + +Mapping serializer to Django model definitions. + +Features : + + * It will automatically generate a set of fields for you, based on the model. + * It will automatically generate validators for the serializer. + * It includes simple default implementations of .create() and .update(). + +.. code-block:: python + + class UserSerializer(dt.ModelSerializer): + + class Meta : + model = User + fields = ('id','username','email') + +You can also set the fields attribute to the special value ``__all__`` to indicate that all fields in the model should be used. + +For example: + +.. code-block:: python + + class UserSerializer(dt.ModelSerializer): + + class Meta : + model = User + fields = '__all__' + +You can set the exclude attribute to a list of fields to be excluded from the serializer. + +For example: + +.. code-block:: python + + class UserSerializer(dt.ModelSerializer): + + class Meta : + model = User + exclude = ('email',) + + +Read&Write only fields +---------------------- + +.. 
code-block:: python + + class UserSerializer(dt.ModelSerializer): + class Meta: + model = User + fields = ('id', 'username', 'password','password_confirmation') + read_only_fields = ('username',) + write_only_fields = ('password','password_confirmation') + +Parsers +------- + +Allow only requests with JSON content, instead of the default of JSON or form data. + +.. code:: python + + REST_FRAMEWORK = { + 'DEFAULT_PARSER_CLASSES': [ + 'drf_turbo.parsers.JSONParser', + ] + } + + or + + REST_FRAMEWORK = { + 'DEFAULT_PARSER_CLASSES': [ + 'drf_turbo.parsers.UJSONParser', + ] + } + + or + + REST_FRAMEWORK = { + 'DEFAULT_PARSER_CLASSES': [ + 'drf_turbo.parsers.ORJSONParser', + ] + } + +**NOTE**: ujson must be installed to use UJSONParser. + +**NOTE**: orjson must be installed to use ORJSONParser. + + + +Renderers +--------- + +Use JSON as the main media type. + +.. code:: python + + + REST_FRAMEWORK = { + 'DEFAULT_RENDERER_CLASSES': [ + 'drf_turbo.renderers.JSONRenderer', + ] + } + + or + + REST_FRAMEWORK = { + 'DEFAULT_RENDERER_CLASSES': [ + 'drf_turbo.renderers.UJSONRenderer', + ] + } + + or + + REST_FRAMEWORK = { + 'DEFAULT_RENDERER_CLASSES': [ + 'drf_turbo.renderers.ORJSONRenderer', + ] + } + +**NOTE**: ujson must be installed to use UJSONRenderer. + +**NOTE**: orjson must be installed to use ORJSONRenderer. + + + +Responses +--------- + +An ``HttpResponse`` subclass that helps to create a JSON-encoded response. Its default Content-Type header is set to application/json. + +.. 
code:: python + + from rest_framework.views import APIView + import drf_turbo as dt + + class UserInfo(APIView): + def get(self,request): + data = {"username":"test"} + return dt.JsonResponse(data,status=200) + + or + + class UserInfo(APIView): + def get(self,request): + data = {"username":"test"} + return dt.UJSONResponse(data,status=200) + + or + + class UserInfo(APIView): + def get(self,request): + data = {"username":"test"} + return dt.ORJSONResponse(data,status=200) + +**NOTE**: ujson must be installed to use UJSONResponse. + +**NOTE**: orjson must be installed to use ORJSONResponse. + + +Also drf-turbo provides an easy way to return a success or error response using ``SuccessResponse`` or ``ErrorResponse`` classes. + +For example: + +.. code:: python + + class UserInfo(APIView): + def get(self,request): + data = {"username":"test"} + serializer = UserSerializer(data=data) + if not serializer.is_valid(): + return dt.ErrorResponse(serializer.errors) + # returned response : {'message':'Bad request', data : ``serializer_errors``, 'error': True} with status = 400 + return dt.SuccessResponse(data) + # returned response : {'message':'Success', data : {"username":"test"} , 'error': False} with status = 200 + + + + + +OpenApi(Swagger) +---------------- + +Add drf-turbo to installed apps in ``settings.py`` + +.. code:: python + + INSTALLED_APPS = [ + # ALL YOUR APPS + 'drf_turbo', + ] + + +and then register our openapi AutoSchema with DRF. + +.. code:: python + + REST_FRAMEWORK = { + # YOUR SETTINGS + 'DEFAULT_SCHEMA_CLASS': 'drf_turbo.openapi.AutoSchema', + } + + +and finally add these lines in ``urls.py`` + +.. 
code:: python + + from django.views.generic import TemplateView + from rest_framework.schemas import get_schema_view as schema_view + from drf_turbo.openapi import SchemaGenerator + + urlpatterns = [ + # YOUR PATTERNS + path('openapi', schema_view( + title="Your Project", + description="API for all things …", + version="1.0.0", + generator_class=SchemaGenerator, + public=True, + ), name='openapi-schema'), + path('docs/', TemplateView.as_view( + template_name='docs.html', + extra_context={'schema_url':'openapi-schema'} + ), name='swagger-ui'), + ] + +Now go to http://127.0.0.1:8000/docs + +Credits +------- + +This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. + +.. _Cookiecutter: https://github.com/audreyr/cookiecutter +.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage diff --git a/docs/Makefile b/docs/Makefile new file mode 100755 index 0000000..4edd235 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,192 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
+ +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " applehelp to make an Apple Help Book" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + @echo " coverage to run coverage check of the documentation (if enabled)" + +clean: + rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
+ +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/serpy.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/serpy.qhc" + +applehelp: + $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp + @echo + @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." + @echo "N.B. You won't be able to view it unless you put it in" \ + "~/Library/Documentation/Help or install it in your application" \ + "bundle." + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/serpy" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/serpy" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 
+ @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." 
+ +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +coverage: + $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage + @echo "Testing of coverage in the sources finished, look at the " \ + "results in $(BUILDDIR)/coverage/python.txt." + +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/docs/_templates/page.html b/docs/_templates/page.html new file mode 100755 index 0000000..829b943 --- /dev/null +++ b/docs/_templates/page.html @@ -0,0 +1,8 @@ +{% extends "!page.html" %} +{% block extrahead %} + {{ super() }} + +{% endblock %} +{% block menu %} + {{ super() }} +{% endblock %} diff --git a/docs/api.rst b/docs/api.rst new file mode 100755 index 0000000..8342789 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,105 @@ +************* +API Reference +************* + +Serializer +========== + +.. currentmodule:: drf_turbo + +.. autoclass:: BaseSerializer + :members: + +.. autoclass:: Serializer + :show-inheritance: + :inherited-members: + :members: + +.. autoclass:: ModelSerializer + :show-inheritance: + :inherited-members: + :members: + +Fields +====== + +.. autoclass:: Field + :members: + +.. autoclass:: drf_turbo.StrField + :members: + +.. autoclass:: drf_turbo.EmailField + :members: + +.. autoclass:: drf_turbo.URLField + :members: + +.. autoclass:: drf_turbo.RegexField + :members: + +.. autoclass:: drf_turbo.IPField + :members: + +.. autoclass:: drf_turbo.UUIDField + :members: + +.. autoclass:: drf_turbo.PasswordField + :members: + +.. autoclass:: drf_turbo.SlugField + :members: + +.. autoclass:: IntField + :members: + +.. 
autoclass:: FloatField + :members: + +.. autoclass:: DecimalField + :members: + +.. autoclass:: BoolField + :members: + +.. autoclass:: ChoiceField + :members: + +.. autoclass:: MultipleChoiceField + :members: + +.. autoclass:: DateTimeField + :members: + +.. autoclass:: DateField + :members: + +.. autoclass:: TimeField + :members: + +.. autoclass:: FileField + :members: + +.. autoclass:: ArrayField + :members: + +.. autoclass:: DictField + :members: + +.. autoclass:: JSONField + :members: + +.. autoclass:: RelatedField + :members: + +.. autoclass:: ManyRelatedField + :members: + +.. autoclass:: ConstantField + :members: + +.. autoclass:: RecursiveField + :members: + +.. autoclass:: MethodField + :members: diff --git a/docs/authors.rst b/docs/authors.rst new file mode 100755 index 0000000..e122f91 --- /dev/null +++ b/docs/authors.rst @@ -0,0 +1 @@ +.. include:: ../AUTHORS.rst diff --git a/docs/conf.py b/docs/conf.py new file mode 100755 index 0000000..2d67705 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# flake8: noqa +# +# serpy documentation build configuration file, created by +# sphinx-quickstart on Thu Apr 2 22:25:55 2015. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import drf_turbo + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.viewcode", + "sphinx.ext.intersphinx", + "sphinx_issues", +] + +primary_domain = "py" +default_role = "py:obj" + +github_user = "Mng-dev-ai" +github_repo = "drf-turbo" + +issues_github_path = f"{github_user}/{github_repo}" + +# The master toctree document. 
+master_doc = "index" +language = "en" +html_domain_indices = False +source_suffix = ".rst" +project = "drf-turbo" +copyright = "2021, Michael Gendy" +version = release = drf_turbo.__version__ +templates_path = ["_templates"] +exclude_patterns = ["_build"] +author = "Michael Gendy" +autoclass_content = "both" + +# Theme +html_theme = 'furo' + + +html_sidebars = { + "*": [ + "sidebar/scroll-start.html", + "sidebar/brand.html", + "sidebar/search.html", + "sidebar/navigation.html", + "sidebar/ethical-ads.html", + "sidebar/scroll-end.html", + ] +} + + + diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100755 index 0000000..e582053 --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1 @@ +.. include:: ../CONTRIBUTING.rst diff --git a/docs/index.rst b/docs/index.rst new file mode 100755 index 0000000..1b8fb6e --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,21 @@ +.. raw:: html + + Fork me on GitHub + + +.. include:: ../README.rst + +Contents: + +.. toctree:: + :maxdepth: 2 + + installation + api + performance + contributing + authors diff --git a/docs/installation.rst b/docs/installation.rst new file mode 100755 index 0000000..cf274af --- /dev/null +++ b/docs/installation.rst @@ -0,0 +1,51 @@ +.. highlight:: shell + +============ +Installation +============ + + +Stable release +-------------- + +To install drf-turbo, run this command in your terminal: + +.. code-block:: console + + $ pip install drf-turbo + +This is the preferred method to install drf-turbo, as it will always install the most recent stable release. + +If you don't have `pip`_ installed, this `Python installation guide`_ can guide +you through the process. + +.. _pip: https://pip.pypa.io +.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ + + +From sources +------------ + +The sources for drf-turbo can be downloaded from the `Github repo`_. + +You can either clone the public repository: + +.. 
code-block:: console + + $ git clone git://github.com/Mng-dev-ai/drf-turbo + +Or download the `tarball`_: + +.. code-block:: console + + $ curl -OJL https://github.com/Mng-dev-ai/drf-turbo/tarball/master + +Once you have a copy of the source, you can install it with: + +.. code-block:: console + + $ python setup.py install + + +.. _Github repo: https://github.com/Mng-dev-ai/drf-turbo +.. _tarball: https://github.com/Mng-dev-ai/drf-turbo/tarball/master diff --git a/docs/performance.rst b/docs/performance.rst new file mode 100755 index 0000000..851fc9f --- /dev/null +++ b/docs/performance.rst @@ -0,0 +1,20 @@ +********************** +Performance Benchmarks +********************** + + +Using Pydantic's own benchmark : +:: + + pydantic best=103.525μs/iter avg=106.093μs/iter stdev=2.720μs/iter version=1.8.2 + attrs + cattrs best=67.449μs/iter avg=67.766μs/iter stdev=0.388μs/iter version=21.2.0 + valideer best=90.025μs/iter avg=91.712μs/iter stdev=1.579μs/iter version=0.4.2 + marshmallow best=198.683μs/iter avg=201.499μs/iter stdev=2.192μs/iter version=3.14.0 + voluptuous best=195.394μs/iter avg=197.416μs/iter stdev=2.399μs/iter version=0.12.2 + trafaret best=221.880μs/iter avg=223.723μs/iter stdev=1.050μs/iter version=2.1.0 + schematics best=728.707μs/iter avg=745.313μs/iter stdev=11.523μs/iter version=2.1.1 + django-rest-framework best=856.090μs/iter avg=885.471μs/iter stdev=39.377μs/iter version=3.12.4 + drf_turbo best=111.754μs/iter avg=113.794μs/iter stdev=1.274μs/iter version=0.3.8 + cerberus best=1561.922μs/iter avg=1638.844μs/iter stdev=44.257μs/iter version=1.3.4 + + diff --git a/docs/readme.rst b/docs/readme.rst new file mode 100755 index 0000000..72a3355 --- /dev/null +++ b/docs/readme.rst @@ -0,0 +1 @@ +.. 
include:: ../README.rst diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100755 index 0000000..ede1aaf --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,10 @@ +forbiddenfruit +djangorestframework +pyyaml +cython +psycopg2-binary +sphinx-issues +furo + + + diff --git a/drf_turbo/__init__.py b/drf_turbo/__init__.py new file mode 100755 index 0000000..a4153cb --- /dev/null +++ b/drf_turbo/__init__.py @@ -0,0 +1,64 @@ +from drf_turbo.serializer import BaseSerializer,Serializer,ModelSerializer +from drf_turbo.fields import ( + Field,StrField,EmailField,URLField,RegexField,IPField,PasswordField,UUIDField,SlugField,IntField,FloatField,DecimalField,BoolField,ChoiceField,MultipleChoiceField,DateTimeField,DateField,TimeField,FileField,ArrayField,DictField,JSONField,RelatedField,ManyRelatedField,ConstantField,RecursiveField,MethodField +) +from drf_turbo.exceptions import ValidationError,ParseError +from drf_turbo.response import JSONResponse,UJSONResponse,ORJSONResponse,SuccessResponse,ErrorResponse +from drf_turbo.parsers import JSONParser,UJSONParser,ORJSONParser +from drf_turbo.renderers import JSONRenderer,UJSONRenderer,ORJSONRenderer + +__author__ = """Michael Gendy""" +__email__ = 'mngback@gmail.com' +__version__ = '0.1.1' + +__all__ = [ + 'BaseSerializer', + 'Serializer', + 'ModelSerializer', + 'Field', + 'StrField', + 'EmailField', + 'URLField', + 'RegexField', + 'IPField', + 'PasswordField', + 'UUIDField', + 'SlugField', + 'IntField', + 'FloatField', + 'DecimalField', + 'BoolField', + 'ChoiceField', + 'MultipleChoiceField', + 'DateTimeField', + 'DateField', + 'TimeField', + 'FileField', + 'ArrayField', + 'DictField', + 'JSONField', + 'RelatedField', + 'ManyRelatedField', + 'ConstantField', + 'RecursiveField', + 'MethodField', + 'ValidationError', + 'ParseError', + 'JSONResponse', + 'UJSONResponse', + 'ORJSONResponse', + 'SuccessResponse', + 'ErrorResponse', + 'JSONParser', + 'UJSONParser', + 'ORJSONParser', + 'JSONRenderer', + 
'UJSONRenderer', + 'ORJSONRenderer', + +] + + + + + diff --git a/drf_turbo/cython_metaclass.h b/drf_turbo/cython_metaclass.h new file mode 100755 index 0000000..cc620a4 --- /dev/null +++ b/drf_turbo/cython_metaclass.h @@ -0,0 +1,107 @@ +/***************************************************************************** +* Copyright (C) 2015 Jeroen Demeyer +* +* This program is free software: you can redistribute it and/or modify +* it under the terms of the GNU General Public License as published by +* the Free Software Foundation, either version 2 of the License, or +* (at your option) any later version. +* http://www.gnu.org/licenses/ +*****************************************************************************/ + +/* Tuple (None, None, None), initialized as needed */ +static PyObject* NoneNoneNone; + +/* All args flags of a PyMethod */ +#define METH_ALLARGS (METH_VARARGS|METH_KEYWORDS|METH_NOARGS|METH_O) + +/* Given an unbound method "desc" (this is not checked!) with only a + * single "self" argument, call "desc(self)" without checking "self". + * This can in particular be used to call any method as class or + * static method. */ +static CYTHON_INLINE PyObject* PyMethodDescr_CallSelf(PyMethodDescrObject* desc, PyObject* self) +{ + PyMethodDef* meth = desc->d_method; + + /* This must be a METH_NOARGS method */ + if (meth == NULL || (meth->ml_flags & METH_ALLARGS) != METH_NOARGS) + { + PyErr_SetString(PyExc_TypeError, + "PyMethodDescr_CallSelf requires a method without arguments"); + return NULL; + } + + return meth->ml_meth(self, NULL); +} + +/* + * This function calls PyType_Ready(t) and then calls + * t.__getmetaclass__(None) (if that method exists) which should + * return the metaclass for t. Then type(t) is set to this metaclass + * and metaclass.__init__(t, None, None, None) is called. 
+ */ +static CYTHON_INLINE int Sage_PyType_Ready(PyTypeObject* t) +{ + int r = PyType_Ready(t); + if (r < 0) + return r; + + /* Set or get metaclass (the type of t) */ + PyTypeObject* metaclass; + + PyObject* getmetaclass; + getmetaclass = PyObject_GetAttrString((PyObject*)t, "__getmetaclass__"); + if (getmetaclass) + { + /* Call getmetaclass with self=None */ + metaclass = (PyTypeObject*)(PyMethodDescr_CallSelf((PyMethodDescrObject*)getmetaclass, Py_None)); + Py_DECREF(getmetaclass); + if (!metaclass) + return -1; + + if (!PyType_Check(metaclass)) + { + PyErr_SetString(PyExc_TypeError, + "__getmetaclass__ did not return a type"); + return -1; + } + + /* Now, set t.__class__ to metaclass */ + Py_TYPE(t) = metaclass; + PyType_Modified(t); + } + else + { + /* No __getmetaclass__ method: read metaclass... */ + PyErr_Clear(); + metaclass = Py_TYPE(t); + } + + /* Now call metaclass.__init__(t, None, None, None) unless + * we would be calling type.__init__ */ + initproc init = metaclass->tp_init; + if (init == NULL || init == PyType_Type.tp_init) + return 0; + + /* Safety check: since we didn't call tp_new of metaclass, + * we cannot safely call tp_init if the size of the structure + * differs. 
*/ + if (metaclass->tp_basicsize != PyType_Type.tp_basicsize) + { + PyErr_SetString(PyExc_TypeError, + "metaclass is not compatible with 'type' (you cannot use cdef attributes in Cython metaclasses)"); + return -1; + } + + /* Initialize a tuple (None, None, None) */ + if (!NoneNoneNone) + { + NoneNoneNone = PyTuple_Pack(3, Py_None, Py_None, Py_None); + if (!NoneNoneNone) return -1; + } + return init((PyObject*)t, NoneNoneNone, NULL); +} + + +/* Use the above function in Cython code instead of the default + * PyType_Ready() function */ +#define PyType_Ready(t) Sage_PyType_Ready(t) diff --git a/drf_turbo/cython_metaclass.pxd b/drf_turbo/cython_metaclass.pxd new file mode 100755 index 0000000..189eb04 --- /dev/null +++ b/drf_turbo/cython_metaclass.pxd @@ -0,0 +1,2 @@ +cdef extern from "cython_metaclass.h": + PyMethodDescr_CallSelf(desc, self) diff --git a/drf_turbo/cython_metaclass.pyx b/drf_turbo/cython_metaclass.pyx new file mode 100755 index 0000000..11294c2 --- /dev/null +++ b/drf_turbo/cython_metaclass.pyx @@ -0,0 +1,123 @@ +""" +Metaclasses for Cython extension types + +Cython does not support metaclasses, but this module can be used to +implement metaclasses for extension types. + +.. WARNING:: + + This module has many caveats and you can easily get segfaults if you + make a mistake. It relies on undocumented Python and Cython + behaviour, so things might break in future versions. + +How to use +========== + +To enable this metaclass mechanism, you need to put +``cimport sage.cpython.cython_metaclass`` in your module (in the ``.pxd`` +file if you are using one). + +In the extension type (a.k.a. ``cdef class``) for which you want to +define a metaclass, define a method ``__getmetaclass__`` with a single +unused argument. This method should return a type to be used as +metaclass: + +.. 
code-block:: cython + + cimport sage.cpython.cython_metaclass + cdef class MyCustomType(object): + def __getmetaclass__(_): + from foo import MyMetaclass + return MyMetaclass + +The above ``__getmetaclass__`` method is analogous to +``__metaclass__ = MyMetaclass`` in Python 2. + +.. WARNING:: + + ``__getmetaclass__`` must be defined as an ordinary method taking a + single argument, but this argument should not be used in the + method (it will be ``None``). + +When a type ``cls`` is being constructed with metaclass ``meta``, +then ``meta.__init__(cls, None, None, None)`` is called from Cython. +In Python, this would be ``meta.__init__(cls, name, bases, dict)``. + +.. WARNING:: + + The ``__getmetaclass__`` method is called while the type is being + created during the import of the module. Therefore, + ``__getmetaclass__`` should not refer to any global objects, + including the type being created or other types defined or imported + in the module (unless you are very careful). Note that non-imported + ``cdef`` functions are not Python objects, so those are safe to call. + + The same warning applies to the ``__init__`` method of the + metaclass. + +.. WARNING:: + + The ``__new__`` method of the metaclass (including the ``__cinit__`` + method for Cython extension types) is never called if you're using + this from Cython. In particular, the metaclass cannot have any + attributes or virtual methods. 
+ +EXAMPLES:: + + sage: cython(''' + ....: cimport sage.cpython.cython_metaclass + ....: cdef class MyCustomType(object): + ....: def __getmetaclass__(_): + ....: class MyMetaclass(type): + ....: def __init__(*args): + ....: print("Calling MyMetaclass.__init__{}".format(args)) + ....: return MyMetaclass + ....: + ....: cdef class MyDerivedType(MyCustomType): + ....: pass + ....: ''') + Calling MyMetaclass.__init__(, None, None, None) + Calling MyMetaclass.__init__(, None, None, None) + sage: MyCustomType.__class__ + + sage: class MyPythonType(MyDerivedType): + ....: pass + Calling MyMetaclass.__init__(, 'MyPythonType', (,), {...}) + +Implementation +============== + +All this is implemented by defining + +.. code-block:: c + + #define PyTypeReady(t) Sage_PyType_Ready(t) + +and then implementing the function ``Sage_PyType_Ready(t)`` which first +calls ``PyType_Ready(t)`` and then handles the metaclass stuff. + +TESTS: + +Check that a proper exception is raised if ``__getmetaclass__`` +returns a non-type:: + + sage: cython(''' + ....: cimport sage.cpython.cython_metaclass + ....: cdef class MyCustomType(object): + ....: def __getmetaclass__(_): + ....: return 2 + ....: ''') + Traceback (most recent call last): + ... + TypeError: __getmetaclass__ did not return a type +""" + +#***************************************************************************** +# Copyright (C) 2015 Jeroen Demeyer +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. 
+# http://www.gnu.org/licenses/ +#***************************************************************************** diff --git a/drf_turbo/exceptions.pyx b/drf_turbo/exceptions.pyx new file mode 100755 index 0000000..e48a209 --- /dev/null +++ b/drf_turbo/exceptions.pyx @@ -0,0 +1,48 @@ +from drf_turbo.utils import get_execption_detail + +cdef class DrfTurboException(Exception): + """Base class for all fast_rest-related errors.""" + status_code = 500 + default_detail = 'A server error occurred.' + default_code = 'error' + + def __cinit__(self, detail=None, code=None) -> None: + if detail is None: + detail = self.default_detail + if code is None: + code = self.default_code + + self.detail = get_execption_detail(detail) + + def __str__(self) -> str: + return str(self.detail) + +cdef class ValidationError(DrfTurboException): + default_detail = 'Invalid input.' + default_code = 'invalid' + + def __cinit__(self, detail=None, code=None) -> None: + if detail is None: + detail = self.default_detail + if code is None: + code = self.default_code + + if isinstance(detail, tuple): + detail = list(detail) + elif not isinstance(detail, dict) and not isinstance(detail, list): + detail = [detail] + + self.detail = get_execption_detail(detail) + +class StringNotCollectionError(DrfTurboException,TypeError): + pass + + +class OnlyAndExcludeError(DrfTurboException): + pass + + +class ParseError(DrfTurboException): + status_code = 400 + default_detail = 'Malformed request.' 
+ default_code = 'parse_error' diff --git a/drf_turbo/fields.pxd b/drf_turbo/fields.pxd new file mode 100755 index 0000000..ae895d0 --- /dev/null +++ b/drf_turbo/fields.pxd @@ -0,0 +1,178 @@ +cdef object NO_DEFAULT + +cdef class SkipField(Exception): + pass + +cdef class Field : + cdef public : + str attr + bint call + bint required + bint write_only + bint read_only + bint allow_null + str label + object default_value + object initial + str help_text + dict style + str field_name + object root + object validators + dict error_messages + list attrs + + cpdef serialize(self, value , dict context) + cpdef deserialize(self,data, dict context) + cpdef method_getter(self,field_name, root) + cpdef void bind(self,basestring name, object root) + cpdef run_validation(self,object data,dict context) + cpdef get_initial(self) + cpdef get_attribute(self, instance,attr=*) + cpdef get_default_value(self) + cpdef validate_empty_values(self, data) + cpdef long validate_or_raise(self,value) except -1 + + +cdef class StrField(Field): + cdef public: + allow_blank + trim_whitespace + max_length + min_length + +cdef class EmailField(StrField): + cdef public: + bint to_lower + +cdef class URLField(StrField): + pass + +cdef class RegexField(StrField): + pass + +cdef class IPField(StrField): + pass + +cdef class PasswordField(StrField): + pass + +cdef class UUIDField(StrField): + pass + +cdef class SlugField(Field): + cdef public: + allow_unicode + +cdef class IntField(Field): + cdef public : + max_value + min_value + +cdef class FloatField(Field): + cdef public : + max_value + min_value + +cdef class DecimalField(Field): + cdef public: + max_digits + decimal_places + max_value + min_value + rounding + coerce_to_string + max_whole_digits + + cdef validate_precision(self,value) + cdef quantize(self,value) + +cdef class BoolField(Field): + pass + +cdef class ChoiceField(Field): + cdef public : + choices + choice_strings_to_values + choice_strings_to_display + allow_blank + +cdef class 
MultipleChoiceField(ChoiceField): + cdef public: + allow_empty + +cdef class DateTimeField(Field): + cdef public: + format + input_formats + default_timezone + timezone + + cpdef get_default_timezone(self) + cpdef enforce_timezone(self, value) + +cdef class DateField(Field): + cdef public: + format + input_formats + +cdef class TimeField(Field): + cdef public: + format + input_formats + +cdef class FileField(Field): + cdef public : + max_length + allow_empty_file + +cdef class ArrayField(Field): + cdef public : + child + allow_empty + min_items + max_items + exact_items + + cpdef run_child_validation(self,data,dict context) + +cdef class DictField(Field): + cdef public : + child + allow_empty + + cpdef run_child_validation(self,dict data,dict context) + + +cdef class JSONField(Field): + cdef public : + binary + encoder + decoder + + +cdef class RelatedField(Field): + cdef public: + queryset + +cdef class ManyRelatedField(Field): + cdef public : + child_relation + allow_empty + +cdef class ConstantField(Field): + cdef public : + constant + +cdef class RecursiveField(Field): + cdef public : + many + context + only + exclude + cpdef serialize(self,value,dict context) + +cdef class MethodField(Field): + cdef public : + method_name + diff --git a/drf_turbo/fields.pyx b/drf_turbo/fields.pyx new file mode 100755 index 0000000..51b70a5 --- /dev/null +++ b/drf_turbo/fields.pyx @@ -0,0 +1,1450 @@ +cimport cython + +from drf_turbo.utils import is_iterable_and_not_string,get_error_detail,is_collection,get_attribute +from drf_turbo.exceptions import * +from django.core.validators import EmailValidator,URLValidator,RegexValidator,MaxLengthValidator,MinLengthValidator,MinValueValidator,MaxValueValidator,ProhibitNullCharactersValidator +from django.core.exceptions import ValidationError as DjangoValidationError +import ipaddress +import copy +import json +from django.core.exceptions import ObjectDoesNotExist +import uuid +from django.utils.dateparse import ( + parse_date, 
parse_datetime, parse_time +) +import decimal,re +from django.utils.encoding import smart_str +from rest_framework.settings import api_settings +import datetime +from rest_framework import ( + ISO_8601 +) +from django.conf import settings +from django.utils import timezone +from django.utils.timezone import utc +from pytz.exceptions import InvalidTimeError + +cdef object NO_DEFAULT = object() + + +cdef class SkipField(Exception): + pass + + +cdef class Field : + """ + Basic field from which other fields should extend. It applies no + formatting by default, and should only be used in cases where + data does not need to be formatted before being serialized or deserialized. + On error, the name of the field will be returned. + + :param str attr: The name of the attribute to get the value from when serializing. + :param bool call: Whether the value should be called after it is retrieved + from the object. Useful if an object has a method to be serialized. + :param bool required: Raise a `ValidationError` if the field value + is not supplied during deserialization. + :param bool write_only: If `True` skip this field during serialization, otherwise + its value will be present in the serialized data. + :param bool read_only: If `True` skip this field during deserialization, otherwise + its value will be present in the deserialized object. + :param str label: A label to use as the name of the serialized field + instead of using the attribute name of the field. + :param list validators: A list of validators to apply against incoming data during deserialization. + :param str field_name: The name of field. + :param str root: The root(parent) of field. + :param default_value: Default value to be used during serialization and deserialization. + :param initial: The initial value for the field. 
+ """ + + is_method_field = False + default_error_messages = { + 'required': 'This field is required.', + 'null': 'This field may not be null.', + } + _initial = None + + def __init__( + self, + basestring attr = None, + bint call= False, + bint required= True, + bint write_only = False, + bint read_only = False, + bint allow_null = False, + basestring label = None, + basestring help_text = None, + dict style = None , + object validators = None, + object default_value = NO_DEFAULT, + object initial = NO_DEFAULT, + basestring field_name = None, + object root = None , + dict error_messages = None, + + ): + required = False if default_value is not NO_DEFAULT else required + assert not (read_only and write_only), 'May not set both `read_only` and `write_only`' + assert not (required and default_value is not NO_DEFAULT), 'May not set both `required` and `default_value`' + + self.attr = attr + self.call = call + self.required = required + self.write_only = write_only + self.read_only = (read_only or call or + (attr is not None and '.' in attr)) # type: ignore + self.allow_null = allow_null + self.label = label + self.default_value = default_value + self.initial = self._initial if (initial is NO_DEFAULT) else initial + self.help_text = help_text + self.style = {} if style is None else style + self.field_name = field_name + self.root = root + if validators is None: + self.validators = [] + elif callable(validators): + self.validators = [validators] + elif is_iterable_and_not_string(validators) : + self.validators = list(validators) + else: + raise ValueError( + "The 'validators' parameter must be a callable " + "or a collection of callables." 
+ ) + + messages = {} + for cls in reversed(self.__class__.__mro__): + messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(error_messages or {}) + self.error_messages = messages + + def raise_if_fail(self, key: str, **kwargs) : + """Helper method to make a `ValidationError` with an error message + from ``self.error_messages``. + """ + try: + msg = self.error_messages[key] + except KeyError as error: + raise AssertionError(error) + if isinstance(msg, (str, bytes)): + msg = msg.format(**kwargs) + return ValidationError(msg) + + + cpdef serialize(self,value, dict context): + """ + Transform the *outgoing* native value into primitive data + + :param value: The outgoing value. + :param context: The context for the request. + """ + return value + + + cpdef deserialize(self,data, dict context): + """ + Transform the *incoming* primitive data into a native value. + + :param data: The incoming data. + :param context: The context for the request. + """ + return data + + + cpdef method_getter(self,field_name, root) : + """ + Returns a function that fetches an attribute from an object. + + :field_name: The name of the attribute to get. + :root: The root of the field. + """ + return None + + + cpdef void bind(self,basestring field_name, object root): + """ + Update the field name and root for the field instance. + + :field_name: The name of the field. + :root: The root of the field. + """ + self.field_name = field_name + self.root = root + if self.label is None: + self.label = field_name.replace('_', ' ').capitalize() + + if self.attr is None: + self.attr = field_name + + self.attrs = self.attr.split('.') if self.attr else [] + + + cpdef get_default_value(self): + """ + Return the default value for this field. 
+ """ + if self.default_value is NO_DEFAULT or getattr(self.root, 'partial', False): + raise SkipField() + if callable(self.default_value): + return self.default_value() + return self.default_value + + cpdef get_initial(self): + """ + Return the initial value for this field. + """ + if callable(self.initial): + return self.initial() + return self.initial + + cpdef get_attribute(self, instance , attr=None): + """ + Return the value of the field from the provided instance. + """ + try: + if attr is None: + return get_attribute(instance, self.attrs) + return get_attribute(instance, attr) + except (KeyError, AttributeError) as exc: + if self.default_value is not NO_DEFAULT: + return self.get_default_value() + if self.allow_null: + return None + if not self.required: + raise SkipField() + msg = ( + 'Got {exc_type} when attempting to get a value for field' + ) + raise type(exc)(msg) + + cpdef validate_empty_values(self, data): + """ + Validate empty values, and either: + * Raise `ValidationError`, indicating invalid data. + * Return (True, data), indicating an empty value that should be + returned without any further validation being applied. + * Return (False, data), indicating a non-empty value, that should + have validation applied as normal. + """ + if self.read_only: + return (True, self.get_default_value()) + + if data is NO_DEFAULT: + if getattr(self.root, 'partial', False): + raise SkipField() + if self.required: + raise ValidationError('This field is required.') + return (True, self.get_default_value()) + + if data is None: + if not self.allow_null: + raise ValidationError('This field may not be null.') + return (True, None) + + return (False, data) + + cpdef run_validation(self,object data,dict context) : + """ + Validate an input data. 
+ """ + + (is_empty_value, data) = self.validate_empty_values(data) + if is_empty_value: + return data + value = self.deserialize(data,context) + self.validate_or_raise(value) + return value + + + cpdef long validate_or_raise(self,value) except -1 : + """ + Validate the value and raise a `ValidationError` if validation fails. + """ + + cdef list errors = [] + for validator in self.validators : + try : + validator(value) + except ValidationError as exc: + if isinstance(exc.detail, dict): + raise + errors.extend(exc.detail) + except DjangoValidationError as exc: + errors.extend(get_error_detail(exc)) + if errors: + raise ValidationError(errors) + + + + + +cdef class StrField(Field): + """" + A field that validates input as an string. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'blank': 'May not be blank.', + 'invalid': 'Not a valid string.' + } + _initial = '' + def __init__(self,**kwargs) : + self.allow_blank = kwargs.pop('allow_blank', False) + self.trim_whitespace = kwargs.pop('trim_whitespace', True) + self.max_length = kwargs.pop('max_length', None) + self.min_length = kwargs.pop('min_length', None) + super().__init__(**kwargs) + if self.max_length is not None: + self.validators.append( + MaxLengthValidator(self.max_length)) + if self.min_length is not None: + self.validators.append(MinLengthValidator(self.min_length)) + self.validators.append(ProhibitNullCharactersValidator()) + + + cpdef serialize(self,value,dict context) : + return str(value) + + cpdef deserialize(self,data,dict context) : + if data == '' or (self.trim_whitespace and str(data).strip() == ''): + if not self.allow_blank: + raise self.raise_if_fail('blank') + if isinstance(data, bool) or not isinstance(data, (str, int, float,)): + raise self.raise_if_fail('invalid') + data = str(data) + return data.strip() if self.trim_whitespace else data + +@cython.final +cdef class EmailField(StrField): + """ + A field that validates input 
as an E-Mail address. + + :param to_lower: If True, convert the value to lowercase before validating. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'invalid': 'Enter a valid email address.' + } + def __init__(self, **kwargs): + self.to_lower = kwargs.pop('to_lower', False) + super().__init__(**kwargs) + validator = EmailValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + + cpdef inline serialize(self,value,dict context): + if self.to_lower : + return value.lower() + return value + + cpdef inline deserialize(self,data,dict context): + if self.to_lower : + return data.lower() + return data + +@cython.final +cdef class URLField(StrField): + """ + A field that validates input as an URL. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'invalid': 'Enter a valid URL.' + } + + def __init__(self, **kwargs): + super().__init__(**kwargs) + validator = URLValidator(message=self.error_messages['invalid']) + self.validators.append(validator) + +@cython.final +cdef class RegexField(StrField): + """ + A field that validates input against a given regular expression. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'invalid': 'This value does not match the required pattern.' + } + + def __init__(self, regex, **kwargs): + super().__init__(**kwargs) + validator = RegexValidator(regex, message=self.error_messages['invalid']) + self.validators.append(validator) + +@cython.final +cdef class IPField(StrField): + """ + A field that validates that input is an IP address. + + :param kwargs: The same keyword arguments that :class:`Field` receives. 
+ """ + + default_error_messages = { + 'invalid': 'Enter a valid IPv4 or IPv6 address.', + } + + cpdef inline deserialize(self,data,dict context): + try: + return ipaddress.ip_address(data) + except (ValueError, TypeError) : + raise self.raise_if_fail('invalid') + +@cython.final +cdef class PasswordField(StrField): + """ + A field that validates input as a password. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + def __init__(self,**kwargs): + kwargs['write_only'] = True + kwargs['min_length'] = 4 + kwargs['required'] = True + super().__init__(**kwargs) + + +@cython.final +cdef class UUIDField(StrField): + """ + A field that validates input as an UUID. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + + "invalid": "Not a valid UUID." + } + + cpdef inline deserialize(self,data,dict context): + if data is None: + return None + if isinstance(data, uuid.UUID): + return data + try: + if isinstance(data, bytes) and len(data) == 16: + return uuid.UUID(bytes=data) + elif isinstance(data,int): + return uuid.UUID(int=data) + elif isinstance(data,str): + return uuid.UUID(hex=data) + else: + return uuid.UUID(data) + except : + raise self.raise_if_fail('invalid') + +@cython.final +cdef class SlugField(Field): + """ + Slug field type. + + :param allow_unicode: If True, allow unicode characters in the field. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'Not a valid slug.', + 'invalid_unicode' : 'Nnot a valid unicode slug.' 
+ } + + def __init__(self, allow_unicode=False, **kwargs): + self.allow_unicode = allow_unicode + super().__init__(**kwargs) + if self.allow_unicode: + validator = RegexValidator(re.compile(r'^[-\w]+\Z', re.UNICODE), message=self.error_messages['invalid_unicode']) + else: + validator = RegexValidator(re.compile(r'^[-a-zA-Z0-9_]+$'), message=self.error_messages['invalid']) + self.validators.append(validator) + + + + + +@cython.final +cdef class IntField(Field): + """ + A field that validates input as an integer. + + :param min_value: The minimum value allowed. + :param max_value : The maximum value allowed. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'A valid integer is required.', + } + re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' + def __init__(self,**kwargs): + self.max_value = kwargs.pop('max_value', None) + self.min_value = kwargs.pop('min_value', None) + super().__init__(**kwargs) + if self.max_value is not None: + self.validators.append( + MaxValueValidator(self.max_value)) + if self.min_value is not None: + self.validators.append( + MinValueValidator(self.min_value)) + + cpdef inline serialize(self,value,dict context): + return int(value) + + cpdef inline deserialize(self,data,dict context): + try: + data = int(self.re_decimal.sub('', str(data))) + except (ValueError, TypeError): + raise self.raise_if_fail('invalid') + return data + +@cython.final +cdef class FloatField(Field): + """ + A field that validates input as a float. + + :param min_value: The minimum value allowed. + :param max_value: The maximum value allowed. + :param kwargs: The same keyword arguments that :class:`Field` receives. 
+ """ + + default_error_messages = { + 'invalid': 'A valid number is required.', + } + + def __init__(self,**kwargs): + self.max_value = kwargs.pop('max_value', None) + self.min_value = kwargs.pop('min_value', None) + super().__init__(**kwargs) + if self.max_value is not None: + self.validators.append( + MaxValueValidator(self.max_value)) + if self.min_value is not None: + self.validators.append( + MinValueValidator(self.min_value)) + + + cpdef inline serialize(self,value,dict context): + return float(value) + + cpdef inline deserialize(self,data,dict context): + try: + return float(data) + except (TypeError, ValueError): + raise self.raise_if_fail('invalid') + +@cython.final +cdef class DecimalField(Field): + """ + A field that validates input as a Python Decimal. + + :param max_digits(required): Maximum number of digits. + :param deciaml_places(required): Number of decimal places. + :param min_value: The minimum value allowed. + :param max_value: The maximum value allowed. + :param coerce_to_string: If True, values will be converted to strings during serialization. + :param rounding: How to round the value during serialization. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'invalid': 'A valid number is required.', + 'max_value': 'Ensure this value is less than or equal to {max_value}.', + 'min_value': 'Ensure this value is greater than or equal to {min_value}.', + 'max_digits': 'Ensure that there are no more than {max_digits} digits in total.', + 'max_decimal_places': 'Ensure that there are no more than {max_decimal_places} decimal places.', + 'max_whole_digits': 'Ensure that there are no more than {max_whole_digits} digits before the decimal point.', + 'max_string_length': 'String value too large.' + } + MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. 
+ + def __init__(self, max_digits, decimal_places, max_value=None, min_value=None,coerce_to_string=None,rounding=None,**kwargs): + self.max_digits = max_digits + self.decimal_places = decimal_places + self.max_value = max_value + self.min_value = min_value + if self.max_digits is not None and self.decimal_places is not None: + self.max_whole_digits = self.max_digits - self.decimal_places + else: + self.max_whole_digits = None + + if coerce_to_string is None: + self.coerce_to_string = api_settings.COERCE_DECIMAL_TO_STRING + else: + self.coerce_to_string = coerce_to_string + + super().__init__(**kwargs) + + if self.max_value is not None: + self.validators.append( + MaxValueValidator(self.max_value)) + if self.min_value is not None: + self.validators.append( + MinValueValidator(self.min_value)) + + if rounding is not None: + valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] + assert rounding in valid_roundings, ( + 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings)) + self.rounding = rounding + + + cpdef inline serialize(self,value,dict context): + if value is None: + if self.coerce_to_string: + return '' + else: + return None + + if not isinstance(value, decimal.Decimal): + value = decimal.Decimal(str(value).strip()) + + quantized = self.quantize(value) + if not self.coerce_to_string: + return quantized + + return '{:f}'.format(quantized) + + + + cpdef inline deserialize(self,data,dict context): + """ + Validate that the input is a decimal number and return a Decimal + instance. 
+ """ + + data = smart_str(data).strip() + + if data == '' and self.allow_null: + return None + + if len(data) > self.MAX_STRING_LENGTH: + raise self.raise_if_fail('max_string_length') + + try: + value = decimal.Decimal(data) + except decimal.DecimalException: + raise self.raise_if_fail('invalid') + + if value.is_nan(): + raise self.raise_if_fail('invalid') + + if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): + raise self.raise_if_fail('invalid') + + return self.quantize(self.validate_precision(value)) + + cdef inline validate_precision(self, value): + """ + Ensure that there are no more than max_digits in the number, and no + more than decimal_places digits after the decimal point. + Override this method to disable the precision validation for input + values or to enhance it in any way you need to. + """ + sign, digittuple, exponent = value.as_tuple() + + if exponent >= 0: + # 1234500.0 + total_digits = len(digittuple) + exponent + whole_digits = total_digits + decimal_places = 0 + elif len(digittuple) > abs(exponent): + # 123.45 + total_digits = len(digittuple) + whole_digits = total_digits - abs(exponent) + decimal_places = abs(exponent) + else: + # 0.001234 + total_digits = abs(exponent) + whole_digits = 0 + decimal_places = total_digits + + if self.max_digits is not None and total_digits > self.max_digits: + raise self.raise_if_fail('max_digits', max_digits=self.max_digits) + if self.decimal_places is not None and decimal_places > self.decimal_places: + raise self.raise_if_fail('max_decimal_places', max_decimal_places=self.decimal_places) + if self.max_whole_digits is not None and whole_digits > self.max_whole_digits: + raise self.raise_if_fail('max_whole_digits', max_whole_digits=self.max_whole_digits) + + return value + + cdef inline quantize(self, value): + """ + Quantize the decimal value to the configured precision. 
+ """ + if self.decimal_places is None: + return value + + context = decimal.getcontext().copy() + if self.max_digits is not None: + context.prec = self.max_digits + return value.quantize( + decimal.Decimal('.1') ** self.decimal_places, + rounding=self.rounding, + context=context + ) + + +@cython.final +cdef class BoolField(Field): + """ + Boolean field type. + + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'Not a valid boolean.' + } + _initial = False + coerce_values = { + "true": True, + "True": True, + "TRUE": True, + 't' : True, + "T" : True, + "on": True, + "1": True, + True : True, + 1: True, + "off": False, + "f" : False, + "F" : False, + "false": False, + "False":False, + "FALSE" : False, + "0": False, + "": False, + False : False, + 0: False, + } + coerce_null_values = {"", "null","Null","NULL","none","None","NONE"} + + cpdef inline serialize(self,value,dict context): + try: + if self.allow_null and value in self.coerce_null_values: + return None + value = self.coerce_values[value] + except (KeyError,TypeError): + pass + return bool(value) + + cpdef inline deserialize(self,data,dict context): + try: + if self.allow_null and data in self.coerce_null_values: + return None + data = self.coerce_values[data] + except (KeyError,TypeError): + raise self.raise_if_fail('invalid') + return data + +cdef class ChoiceField(Field): + """ + Choice field type. + + :param choices(required): A list of valid choices. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid_choice': '"{input}" is not a valid choice.' 
+ } + + def __init__(self, choices, **kwargs): + + pairs = [ + isinstance(item, (list, tuple)) and len(item) == 2 + for item in choices + ] + if all(pairs): + self.choices = dict([(key, display_value) for key, display_value in choices]) + else: + self.choices = dict([(item, item) for item in choices]) + + self.choice_strings_to_values = dict([ + (str(key), key) for key in self.choices.keys() + ]) + + self.choice_strings_to_display = dict([ + (str(key), value) for key,value in self.choices.items() + ]) + + + self.allow_blank = kwargs.pop('allow_blank', False) + super().__init__(**kwargs) + + + + cpdef serialize(self,value,dict context): + if value in ('', None): + return value + + return { + 'value': self.choice_strings_to_values.get(str(value), value), + 'display': self.choice_strings_to_display.get(str(value), value), + } + + + cpdef deserialize(self,data,dict context) : + if data == '' and self.allow_blank: + return '' + try: + return self.choice_strings_to_values[str(data)] + except: + raise self.raise_if_fail('invalid_choice', input=data) + + +@cython.final +cdef class MultipleChoiceField(ChoiceField): + """ + Multiple choice field type. + + :param allow_empty: If True, allow the user to leave the field blank. + :param kwargs: The same keyword arguments that :class:`Field` and class:`ChoiceField` receives. + """ + default_error_messages = { + 'not_a_list': 'Expected a list of items but got type "{input_type}".', + 'empty': 'This selection may not be empty.', + 'invalid_choice': '"{input}" is not a valid choice.' 
+ } + + def __init__(self, **kwargs): + self.allow_empty = kwargs.pop('allow_empty', True) + super().__init__(**kwargs) + + + cpdef inline serialize(self,value,dict context): + return { + self.choice_strings_to_values.get(str(item), item) for item in value + } + + cpdef inline deserialize(self,data,dict context) : + if isinstance(data, str) or not hasattr(data, '__iter__'): + raise self.raise_if_fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + raise self.raise_if_fail('empty') + + new_data = set() + for item in data : + if item == '' and self.allow_blank: + return '' + try: + new_data.add(self.choice_strings_to_values[str(item)]) + except: + raise self.raise_if_fail('invalid_choice', input=item) + return new_data + + + + +@cython.final +cdef class DateTimeField(Field): + """ + A field that (de)serializes to a :class:`datetime.datetime` object. + + :param format: The format to use when serializing/deserializing. + :param input_formats: A list of formats to check for when deserializing input. + :param default_timezone: The timezone to use when creating datetime instances. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'Not a valid datetime.', + 'date': 'Expected a datetime but got a date.', + 'make_aware':'Invalid datetime for the timezone "{timezone}".', + 'overflow': 'Datetime value out of range.' 
+ } + datetime_parser = datetime.datetime.strptime + + def __init__(self, format=NO_DEFAULT,input_formats=None,default_timezone=None,**kwargs): + if format is NO_DEFAULT: + self.format = api_settings.DATETIME_FORMAT + else: + self.format = format + if input_formats is None: + self.input_formats = api_settings.DATETIME_INPUT_FORMATS + else: + self.input_formats = input_formats + + if default_timezone is None: + self.timezone = self.get_default_timezone() + else: + self.timezone = default_timezone + + super().__init__(**kwargs) + + cpdef inline get_default_timezone(self): + return timezone.get_current_timezone() if settings.USE_TZ else None + + + cpdef inline enforce_timezone(self, value): + """ + When `self.default_timezone` is `None`, always return naive datetimes. + When `self.default_timezone` is not `None`, always return aware datetimes. + """ + if self.timezone is not None: + if timezone.is_aware(value): + try: + return value.astimezone(self.timezone) + except OverflowError: + raise self.raise_if_fail('overflow') + try: + return timezone.make_aware(value, self.timezone) + except InvalidTimeError: + raise self.raise_if_fail('make_aware', timezone=self.timezone) + elif (self.timezone is None) and timezone.is_aware(value): + return timezone.make_naive(value, utc) + return value + + + cpdef inline serialize(self,value,dict context): + if not value: + return None + + if self.format is None or isinstance(value, str): + return value + + value = self.enforce_timezone(value) + + if self.format.lower() == ISO_8601: + value = value.isoformat() + if value.endswith('+00:00'): + value = value[:-6] + 'Z' + return value + return value.strftime(self.format) + + + cpdef inline deserialize(self,data,dict context) : + + + if isinstance(data, datetime.date) and not isinstance(data, datetime.datetime): + raise self.raise_if_fail('date') + + if isinstance(data, datetime.datetime): + return self.enforce_timezone(data) + + for input_format in self.input_formats: + if 
input_format.lower() == ISO_8601: + try: + parsed = parse_datetime(data) + if parsed is not None: + return self.enforce_timezone(parsed) + except (ValueError, TypeError): + pass + else: + try: + parsed = self.datetime_parser(data, input_format) + return self.enforce_timezone(parsed) + except (ValueError, TypeError): + pass + + raise self.raise_if_fail('invalid') + +@cython.final +cdef class DateField(Field): + """ + A field that (de)serializes to a :class:`datetime.date` object. + + :param format: Either ``"%Y-%m-%d"`` or ``"%m/%d/%Y"``, or a custom + string of the format passed to ``strftime``. + :param input_formats: A list of strings, or a custom list of input formats + to be used to parse the input date string. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'invalid': 'Not a valid date.', + 'datetime': 'Expected a date but got a datetime.', + } + + def __init__(self, format=None,input_formats=None,**kwargs): + if format is None: + self.format = api_settings.DATE_FORMAT + else: + self.format = format + + if input_formats is None: + self.input_formats = api_settings.DATE_INPUT_FORMATS + else: + self.input_formats = input_formats + + super().__init__(**kwargs) + + + cpdef inline serialize(self, value,dict context): + if not value: + return None + if isinstance(value, str): + return value + assert not isinstance(value, datetime.datetime), ( + 'Expected a `date`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' 
+ ) + + if self.format.lower() == ISO_8601: + return value.isoformat() + + return value.strftime(self.format) + + + cpdef inline deserialize(self, data,dict context) : + + if isinstance(data, datetime.datetime): + raise self.raise_if_fail('datetime') + + if isinstance(data, datetime.date): + return data + + for input_format in self.input_formats: + if input_format.lower() == ISO_8601: + try: + parsed = parse_date(data) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = self.datetime_parser(data, input_format) + except (ValueError, TypeError): + pass + else: + return parsed.date() + + raise self.raise_if_fail('invalid') + + +@cython.final +cdef class TimeField(Field): + """ + A field that (de)serializes to a :class:`datetime.time` object. + + :param format: Either "time" or "datetime" for serialization to display + time in either 24 hour or 12 hour+minute+second format. + :param input_formats: Optional list of formats to also be accepted. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'Not a valid time.', + } + datetime_parser = datetime.datetime.strptime + + def __init__(self, format=NO_DEFAULT,input_formats=None, **kwargs): + if format is NO_DEFAULT: + self.format = api_settings.DATETIME_FORMAT + else: + self.format = format + if input_formats is None: + self.input_formats = api_settings.TIME_INPUT_FORMATS + else: + self.input_formats = input_formats + + super().__init__(**kwargs) + + + cpdef inline serialize(self, value,dict context): + if value in (None, ''): + return None + + + if self.format is None or isinstance(value, str): + return value + + assert not isinstance(value, datetime.datetime), ( + 'Expected a `time`, but got a `datetime`. Refusing to coerce, ' + 'as this may mean losing timezone information. Use a custom ' + 'read-only field and deal with timezone issues explicitly.' 
+ ) + + if self.format.lower() == ISO_8601: + return value.isoformat() + return value.strftime(self.format) + + + + cpdef inline deserialize(self, data, dict context): + + if isinstance(data, datetime.time): + return data + + for input_format in self.input_formats: + if input_format.lower() == ISO_8601: + try: + parsed = parse_time(data) + except (ValueError, TypeError): + pass + else: + if parsed is not None: + return parsed + else: + try: + parsed = self.datetime_parser(data, input_format) + except (ValueError, TypeError): + pass + else: + return parsed.time() + + raise self.raise_if_fail('invalid') + + + +@cython.final +cdef class FileField(Field): + """ + A file field. + + :param max_length: The maximum file size. + :param alow_empty_file: Whether to allow uploading empty files. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'required': 'No file was submitted.', + 'invalid': 'The submitted data was not a file. Check the encoding type on the form.', + 'no_name': 'No filename could be determined.', + 'empty': 'The submitted file is empty.', + 'max_length': 'Ensure this filename has at most {max_length} characters (it has {length}).', + } + + def __init__(self,**kwargs): + self.max_length = kwargs.pop('max_length', None) + self.allow_empty_file = kwargs.pop('allow_empty_file', False) + super().__init__(**kwargs) + + + cpdef inline serialize(self,value,dict context): + if not value : + return None + try: + request = context.get('request', None) + return request.build_absolute_uri(value.url) + except: + return value.url + + + cpdef inline deserialize(self,value,dict context): + try: + file_name = value.name + file_size = value.size + except AttributeError: + raise self.raise_if_fail('invalid') + if not file_name: + raise self.raise_if_fail('no_name') + if not self.allow_empty_file and not file_size: + raise self.raise_if_fail('empty') + if self.max_length and len(file_name) > self.max_length: + raise 
self.raise_if_fail('max_length',max_length=self.max_length, length=len(file_name)) + + return value + +@cython.final +cdef class ArrayField(Field): + """ + An Array Field. + + :param child: The field to validate the array elements. + :param allow_empty: Whether the array can be empty. + :param max_items: The maximum number of items allowed in the array. + :param min_items: The minimum number of items required in the array. + :param exac_items: The exact number of items required in the array. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + + default_error_messages = { + 'not_a_list': 'Expected a list of items but got type "{input_type}".', + 'empty': 'This list may not be empty.', + 'exact_items': 'Must have {exact_items} items.', + + } + _initial = [] + + def __init__(self,child=None,**kwargs): + self.child = kwargs.pop('child', copy.deepcopy(child)) + self.allow_empty = kwargs.pop('allow_empty', True) + self.max_items = kwargs.pop('max_items', None) + self.min_items = kwargs.pop('min_items', None) + self.exact_items = kwargs.pop('exact_items', None) + super().__init__(**kwargs) + if self.child is not None: + self.child.bind(field_name='', root=self) + if self.max_items is not None: + self.validators.append(MaxLengthValidator(self.max_items,message=f'Must have no more than {self.max_items} items.')) + if self.min_items is not None: + self.validators.append(MinLengthValidator(self.min_items,message=f'Must have at least {self.min_items} items.')) + + if self.exact_items is not None: + self.min_items = self.exact_items + self.max_items = self.exact_items + + + cpdef inline serialize(self,data ,dict context): + """ + List of object instances -> List of dicts of primitive datatypes. + """ + return [ self.child.serialize(item,context) if item is not None else None for item in data] + + cpdef inline deserialize(self,data,dict context): + """ + List of dicts of native values <- List of dicts of primitive datatypes. 
+ """ + if isinstance(data, (str, dict)) or not hasattr(data, '__iter__'): + raise self.raise_if_fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + raise self.raise_if_fail('empty') + if ( + self.min_items is not None + and self.min_items == self.max_items + and len(data) != self.min_items + ): + raise self.raise_if_fail("exact_items",exact_items=self.exact_items) + return self.run_child_validation(data,context) + + cpdef inline run_child_validation(self,data,dict context): + cdef list result = [] + cdef dict errors = {} + cdef int idx = 0 + for item in data : + try: + result.append(self.child.run_validation(item,context)) + except ValidationError as e: + errors[idx] = e.detail + idx +=1 + + if not errors: + return result + raise ValidationError(errors) + +@cython.final +cdef class DictField(Field): + """A Dict Field. + + :param child: The field to validate the dict values. + :param allow_empty: Whether the dict can be empty. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'not_a_dict': 'Expected a dict of items but got type "{input_type}".', + 'empty': 'This dict may not be empty.', + } + _initial = {} + + def __init__(self,child=None,**kwargs): + self.child = kwargs.pop('child', copy.deepcopy(child)) + self.allow_empty = kwargs.pop('allow_empty', True) + super().__init__(**kwargs) + if self.child is not None: + self.child.bind(field_name='', root=self) + + cpdef inline serialize(self, data,dict context): + """ + List of object instances -> List of dicts of primitive datatypes. + """ + return { + str(key) : self.child.serialize(value,context) if value is not None else None for key,value in data.items() + } + + cpdef inline deserialize(self,data,dict context): + """ + List of dicts of native values <- List of dicts of primitive datatypes. 
+ """ + if not isinstance(data,dict): + raise self.raise_if_fail('not_a_dict', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + raise self.raise_if_fail('empty') + + return self.run_child_validation(data,context) + + cpdef inline run_child_validation(self,dict data,dict context): + cdef dict result = {} + cdef dict errors = {} + for key,value in data.items(): + try: + result[str(key)] = self.child.run_validation(value,context) + except ValidationError as e: + errors[key] = e.detail + + if not errors: + return result + raise ValidationError(errors) + +@cython.final +cdef class JSONField(Field): + """A JSON Field. + + :param binary: Whether to load/dump JSON as binary data. + :param encoder: The JSON encoder class to use. + :param decoder: The JSON decoder class to use. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'invalid': 'Not a valid JSON.' + } + + def __init__(self, **kwargs): + self.binary = kwargs.pop('binary', False) + self.encoder = kwargs.pop('encoder', None) + self.decoder = kwargs.pop('decoder', None) + super().__init__(**kwargs) + + cpdef inline serialize(self,value,dict context): + if self.binary: + value = json.dumps(value,cls=self.encoder) + value = value.encode() + return value + + cpdef inline deserialize(self,data,dict context): + try: + if self.binary: + if isinstance(data, bytes): + data = data.decode() + return json.loads(data,cls=self.decoder) + else: + json.dumps(data,cls=self.encoder) + except (TypeError, ValueError) as e: + raise self.raise_if_fail('invalid') + return data + + +@cython.final +cdef class RelatedField(Field): + """A Related Field. + + :param queryset: The queryset to use for getting the instance from the given value. + """ + default_error_messages = { + 'does_not_exist': 'Invalid pk "{pk_value}" - object does not exist.', + 'incorrect_type': 'Incorrect type. 
Expected pk value, received {data_type}.', + + } + def __init__(self,**kwargs): + self.queryset = kwargs.pop('queryset', None) + super().__init__(**kwargs) + + cpdef inline deserialize(self,data,dict context): + try: + if isinstance(data, bool): + raise TypeError + data = self.queryset.get(pk=data) + except ObjectDoesNotExist: + raise self.raise_if_fail('does_not_exist', pk_value=data) + except (TypeError, ValueError): + raise self.raise_if_fail('incorrect_type', data_type=type(data).__name__) + return data + + +@cython.final +cdef class ManyRelatedField(Field): + """ + A field used to represent a to-many relationship. + + :param child_relation: The model field that is the reverse of the relation. + :param allow_empty: Whether the list can be empty. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'not_a_list': 'Expected a list of items but got type "{input_type}".', + 'empty': 'This list may not be empty.' + } + def __init__(self,**kwargs): + self.child_relation = kwargs.pop('child_relation', None) + self.allow_empty = kwargs.pop('allow_empty', True) + super().__init__(**kwargs) + + cpdef inline serialize(self,value,dict context): + + value = value.all() if hasattr(value, 'all') else value + return [ + item.pk for item in value + ] + + cpdef inline deserialize(self, data,dict context): + if isinstance(data, str) or not hasattr(data, '__iter__'): + raise self.raise_if_fail('not_a_list', input_type=type(data).__name__) + if not self.allow_empty and len(data) == 0: + raise self.raise_if_fail('empty') + return [ + self.child_relation.deserialize(item,context) + for item in data + ] + + +@cython.final +cdef class ConstantField(Field): + """ + A field that always outputs a fixed value. + + :param constant: The value to return. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + default_error_messages = { + 'constant': 'Must be "{constant}".', + 'None': 'Must be None.' 
+ } + def __init__(self,constant,**kwargs): + self.constant = constant + super().__init__(**kwargs) + assert "allow_null" not in kwargs + + cpdef inline deserialize(self,data,dict context): + if data != self.constant: + if self.constant is None: + raise self.raise_if_fail("None") + raise self.raise_if_fail("constant",constant=self.constant) + return data + + +@cython.final +cdef class RecursiveField(Field): + """ + A field that recursively validates its data. + + :param many: Whether the field is a collection of items. + :param context: The context passed to the field's :meth:`run_validation`. + :param only: A tuple or list of field names to include. + :param exclude: A tuple or list of field names to exclude. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + def __init__(self,**kwargs): + self.many = kwargs.pop('many',False) + self.context = kwargs.pop('context',{}) + self.only = kwargs.pop('only',None) + self.exclude = kwargs.pop('exclude',None) + if self.only is not None and self.exclude is not None : + raise OnlyAndExcludeError('You should use either "only" or "exclude"') + if self.only is not None and not is_collection(self.only): + raise StringNotCollectionError('"only" should be a list of strings') + if self.exclude is not None and not is_collection(self.exclude): + raise StringNotCollectionError('"exclude" should be a list of strings') + super().__init__(**kwargs) + + cpdef inline serialize(self,value,dict context): + if self.only : + serializer = self.root.__class__(value,many=self.many,only=self.only,context=self.context) + elif self.exclude : + serializer = self.root.__class__(value,many=self.many,exclude=self.exclude,context=self.context) + else: + serializer = self.root.__class__(value,many=self.many,context=self.context) + return serializer.data + + +@cython.final +cdef class MethodField(Field): + """ + A field that calls a method on the serializer instead of simply returning a value. 
+ + :param method_name: The name of the method to call. + :param kwargs: The same keyword arguments that :class:`Field` receives. + """ + is_method_field = True + + def __init__(self,method_name=None,**kwargs): + kwargs['read_only']= True + kwargs['required'] = False + self.method_name = method_name + super().__init__(**kwargs) + + cpdef inline method_getter(self,field_name, root) : + if self.method_name is None: + self.method_name = 'get_{0}'.format(field_name) + return getattr(root, self.method_name) \ No newline at end of file diff --git a/drf_turbo/meta.py b/drf_turbo/meta.py new file mode 100755 index 0000000..e028249 --- /dev/null +++ b/drf_turbo/meta.py @@ -0,0 +1,191 @@ +from django.db import models +from drf_turbo.fields import ( + Field,IntField,BoolField,StrField,DateField,TimeField,DateTimeField,EmailField,FileField,FloatField,URLField,IPField,UUIDField,RelatedField,ManyRelatedField,ChoiceField,DecimalField,ArrayField,JSONField +) +from django.contrib.postgres import fields as postgres_fields + +class SerializerMetaclass(type): + @classmethod + def _get_fields(cls,bases,attrs): + fields = [(field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field)] + for base in reversed(bases): + if hasattr(base, '_fields'): + fields = list( + getattr(base,'_fields').items() + ) + fields + return dict(fields) + + + + def __new__(cls, name, bases, attrs): + attrs['_fields'] = cls._get_fields(bases, attrs) + return super().__new__(cls, name, bases, attrs) + + +class ModelSerializerMetaclass(SerializerMetaclass): + + TYPE_MAPPING = { + models.AutoField: IntField, + models.BigIntegerField: IntField, + models.BooleanField: BoolField, + models.CharField: StrField, + models.DateField: DateField, + models.TimeField : TimeField, + models.DateTimeField: DateTimeField, + models.EmailField: EmailField, + models.FileField: FileField, + models.FloatField: FloatField, + models.ImageField: FileField, + models.IntegerField: IntField, + 
class ModelSerializerMetaclass(SerializerMetaclass):
    """Metaclass that auto-builds serializer fields from a Django model.

    Reads ``Meta.model`` plus ``fields`` / ``exclude`` (and optionally
    ``read_only_fields`` / ``write_only_fields``) and generates a drf-turbo
    field for every model field not explicitly declared on the serializer.
    """

    # Maps Django model field classes to drf-turbo serializer field classes.
    TYPE_MAPPING = {
        models.AutoField: IntField,
        models.BigIntegerField: IntField,
        models.BooleanField: BoolField,
        models.CharField: StrField,
        models.DateField: DateField,
        models.TimeField: TimeField,
        models.DateTimeField: DateTimeField,
        models.EmailField: EmailField,
        models.FileField: FileField,
        models.FloatField: FloatField,
        models.ImageField: FileField,
        models.IntegerField: IntField,
        models.NullBooleanField: BoolField,
        models.PositiveIntegerField: IntField,
        models.PositiveSmallIntegerField: IntField,
        models.SmallIntegerField: IntField,
        models.TextField: StrField,
        models.URLField: URLField,
        models.GenericIPAddressField: IPField,
        models.JSONField: JSONField,
        models.UUIDField: UUIDField,
        postgres_fields.ArrayField: ArrayField,
        # NOTE(review): django.contrib.postgres.fields.JSONField was removed
        # in Django 4.0 -- confirm the supported Django version range.
        postgres_fields.JSONField: JSONField,
    }

    @staticmethod
    def _get_implicit_fields(model_fields, fields, exclude):
        """Resolve which model fields should get auto-generated serializer fields.

        ``fields='__all__'`` selects everything; a ``fields`` list filters by
        name; otherwise ``exclude`` removes names from the full set.
        """
        if fields == '__all__':
            fields = model_fields
        elif fields and isinstance(fields, (list, tuple)):
            fields = [
                field for field in model_fields
                if field.name in fields
            ]
        elif not fields and exclude:
            fields = [
                field for field in model_fields
                if field.name not in exclude
            ]
        # this implicitly handles the case when `fields` is set and `exclude`
        # isn't. Then all fields declared will be returned without any
        # modification.
        return fields

    @staticmethod
    def _filter_fields(cls, declared_fields, explicit_fields, implicit_fields):
        """Generate a serializer field for every implicit (non-declared) model field."""
        for field in implicit_fields:
            if field.name not in explicit_fields:
                klass = field.__class__
                if issubclass(klass, (models.ForeignKey, models.OneToOneField)):
                    field_obj = RelatedField()
                    field_obj.queryset = field.related_model.objects
                elif issubclass(klass, models.ManyToManyField):
                    field_obj = ManyRelatedField()
                    child_obj = RelatedField()
                    child_obj.queryset = field.related_model.objects
                    field_obj.child_relation = child_obj
                elif issubclass(klass, models.DecimalField):
                    field_obj = DecimalField(max_digits=field.max_digits,
                                             decimal_places=field.decimal_places)
                elif field.choices is not None:
                    # A model field with `choices` maps to a ChoiceField.
                    # (Replaces a bare `except:` around field.get_choices()
                    # that silently swallowed every exception.)
                    field_obj = ChoiceField(choices=field.choices)
                else:
                    field_obj = cls.TYPE_MAPPING.get(klass, Field)()

                field_obj.attr = field.name
                field_obj.help_text = str(field.help_text)
                field_obj.validators = field.validators
                # Primary keys and non-editable fields are never writable.
                if issubclass(klass, models.AutoField) or not field.editable:
                    field_obj.read_only = True
                if field.has_default() or field.blank or field.null:
                    field_obj.required = False
                if field.get_internal_type() == 'TextField':
                    field_obj.default_value = ''
                if field.has_default():
                    # Unwrap enum-style defaults that expose `.value`.
                    field_obj.default_value = (field.default.value
                                               if hasattr(field.default, 'value')
                                               else field.default)
                if getattr(field, 'auto_now', False):
                    field_obj.default_value = field.auto_now
                if getattr(field, 'auto_now_add', False):
                    field_obj.default_value = field.auto_now_add

                declared_fields[field.name] = field_obj

        return declared_fields

    @staticmethod
    def _read_only_fields(read_only_fields, declared_fields):
        """Apply Meta.read_only_fields: mark each listed field read-only."""
        if read_only_fields:
            if not isinstance(read_only_fields, (list, tuple)):
                raise TypeError(
                    'The `read_only_fields` option must be a list or tuple. '
                    'Got %s.' % type(read_only_fields).__name__
                )
            for field in read_only_fields:
                declared_fields[field].read_only = True
                declared_fields[field].required = False

    @staticmethod
    def _write_only_fields(write_only_fields, declared_fields):
        """Apply Meta.write_only_fields: mark each listed field write-only."""
        if write_only_fields:
            if not isinstance(write_only_fields, (list, tuple)):
                raise TypeError(
                    'The `write_only_fields` option must be a list or tuple. '
                    'Got %s.' % type(write_only_fields).__name__
                )
            for field in write_only_fields:
                declared_fields[field].write_only = True

    def __new__(cls, name, bases, attrs):
        klass = super().__new__(cls, name, bases, attrs)
        declared_fields = attrs['_fields']
        meta = getattr(klass, 'Meta', None)
        if meta:
            model = getattr(meta, 'model', None)
            fields = getattr(meta, 'fields', None)
            exclude = getattr(meta, 'exclude', None)
            read_only_fields = getattr(meta, 'read_only_fields', None)
            write_only_fields = getattr(meta, 'write_only_fields', None)
            if not model:
                raise RuntimeError(
                    'If you specify a Meta class, you need to at least specify a model'
                )
            if not fields and not exclude:
                raise RuntimeError(
                    'You need to specify either `fields` or `exclude` in Meta'
                )
            if fields and exclude:
                raise RuntimeError(
                    '`fields` and `exclude` prohibit each other.'
                )

            if hasattr(model, "_meta"):
                # Django models: concrete fields plus many-to-many fields.
                model_fields = [field for field in model._meta.fields]
                many_to_many_fields = [field for field in model._meta.many_to_many]
                model_fields.extend(many_to_many_fields)
                implicit_fields = cls._get_implicit_fields(
                    model_fields, fields, exclude
                )
                explicit_fields = declared_fields.keys()
                declared_fields = cls._filter_fields(
                    cls, declared_fields, explicit_fields, implicit_fields
                )
                cls._read_only_fields(read_only_fields, declared_fields)
                cls._write_only_fields(write_only_fields, declared_fields)

                klass._model_fields = model_fields
            klass._fields = declared_fields

        return klass
class SchemaGenerator(BaseSchemaGenerator):
    """OpenAPI 3.0.2 schema generator with drf-turbo and security support.

    Extends DRF's ``BaseSchemaGenerator``; per-view operations come from each
    view's ``schema`` (an :class:`AutoSchema`), and ``securitySchemes`` are
    derived from the security requirements those operations declare.

    Note: the original file defined ``get_info``, ``check_duplicate_operation_id``
    and ``get_schema`` twice; the first definitions were dead code shadowed by
    the later ones (and the first ``get_schema`` omitted ``securitySchemes``).
    Only the effective definitions are kept here.
    """

    def get_info(self):
        """Return the OpenAPI ``info`` object."""
        # Title and version are required by openapi specification 3.x
        info = {
            'title': self.title or '',
            'version': self.version or ''
        }

        if self.description is not None:
            info['description'] = self.description

        return info

    def check_duplicate_operation_id(self, paths):
        """Warn once per collision when two operations share an operationId."""
        ids = {}
        for route in paths:
            for method in paths[route]:
                if 'operationId' not in paths[route][method]:
                    continue
                operation_id = paths[route][method]['operationId']
                if operation_id in ids:
                    warnings.warn(
                        'You have a duplicated operationId in your OpenAPI schema: {operation_id}\n'
                        '\tRoute: {route1}, Method: {method1}\n'
                        '\tRoute: {route2}, Method: {method2}\n'
                        '\tAn operationId has to be unique across your schema. Your schema may not work in other tools.'
                        .format(
                            route1=ids[operation_id]['route'],
                            method1=ids[operation_id]['method'],
                            route2=route,
                            method2=method,
                            operation_id=operation_id
                        )
                    )
                ids[operation_id] = {
                    'route': route,
                    'method': method
                }

    def get_schema(self, request=None, public=False):
        """
        Generate an OpenAPI schema.
        """
        self._initialise_endpoints()
        components_schemas = {}

        # Iterate endpoints generating per method path operations.
        paths = {}
        _, view_endpoints = self._get_paths_and_endpoints(None if public else request)
        for path, method, view in view_endpoints:
            if not self.has_view_permissions(path, method, view):
                continue

            operation = view.schema.get_operation(path, method)
            components = view.schema.get_components(path, method)
            # Warn only when an existing component is replaced by a
            # DIFFERENT definition; identical re-registrations are fine.
            for k in components.keys():
                if k not in components_schemas:
                    continue
                if components_schemas[k] == components[k]:
                    continue
                warnings.warn('Schema component "{}" has been overridden with a different value.'.format(k))

            components_schemas.update(components)

            # Normalise path for any provided mount url.
            if path.startswith('/'):
                path = path[1:]
            path = urljoin(self.url or '/', path)

            paths.setdefault(path, {})
            paths[path][method.lower()] = operation

        self.check_duplicate_operation_id(paths)

        # Compile final schema.
        schema = {
            'openapi': '3.0.2',
            'info': self.get_info(),
            'paths': paths,
        }

        if len(components_schemas) > 0:
            schema['components'] = {
                'schemas': components_schemas,
                'securitySchemes': self.get_security_schemes(paths)
            }

        return schema

    def get_security_schemes(self, paths):
        """Collect securityScheme definitions referenced by the operations.

        Scans every operation's ``security`` requirements and emits a matching
        scheme definition for the names :class:`AutoSchema` produces
        (``BasicAuth``, ``cookieAuth``, ``BearerAuth``).
        """
        security_schemes = {}
        for path, method in paths.items():
            for _, operation in method.items():
                for security in operation['security']:
                    name = next(iter(security))
                    if name == 'BasicAuth':
                        security_schemes[name] = {
                            'type': 'http',
                            'scheme': 'basic',
                            'description': 'Basic authentication'
                        }
                    elif name == 'cookieAuth':
                        security_schemes[name] = {
                            'type': 'apiKey',
                            'in': 'cookie',
                            'description': 'Session authentication',
                        }
                    elif name == 'BearerAuth':
                        security_schemes[name] = {
                            'type': 'http',
                            'scheme': 'bearer',
                            'description': 'Bearer authentication'
                        }
        return security_schemes
+ """ + if tags and not all(isinstance(tag, str) for tag in tags): + raise ValueError('tags must be a list or tuple of string.') + self._tags = tags + self.operation_id_base = operation_id_base + self.component_name = component_name + super().__init__() + + request_media_types = [] + response_media_types = [] + + method_mapping = { + 'get': 'retrieve', + 'post': 'create', + 'put': 'update', + 'patch': 'partialUpdate', + 'delete': 'destroy', + } + + def get_operation(self, path, method): + operation = {} + + operation['operationId'] = self.get_operation_id(path, method) + operation['description'] = self.get_description(path, method) + + parameters = [] + parameters += self.get_path_parameters(path, method) + parameters += self.get_pagination_parameters(path, method) + parameters += self.get_filter_parameters(path, method) + operation['parameters'] = parameters + + request_body = self.get_request_body(path, method) + if request_body: + operation['requestBody'] = request_body + operation['responses'] = self.get_responses(path, method) + operation['tags'] = self.get_tags(path, method) + operation['security'] = self._get_security(path, method) + + return operation + + + def _get_security(self, path, method): + security = [] + for auth_class in self.view.authentication_classes: + if issubclass(auth_class, BasicAuthentication): + security.append({'BasicAuth': []}) + elif issubclass(auth_class,SessionAuthentication): + security.append({'cookieAuth' : []}) + elif issubclass(auth_class, JWTAuthentication): + security.append({'BearerAuth': []}) + return security + + + def get_component_name(self, serializer): + """ + Compute the component's name from the serializer. + Raise an exception if the serializer's class name is "Serializer" (case-insensitive). + """ + if self.component_name is not None: + return self.component_name + + # use the serializer's class name as the component name. 
+ component_name = serializer.__class__.__name__ + # We remove the "serializer" string from the class name. + pattern = re.compile("serializer", re.IGNORECASE) + component_name = pattern.sub("", component_name) + + if component_name == "": + raise Exception( + '"{}" is an invalid class name for schema generation. ' + 'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"' + .format(serializer.__class__.__name__) + ) + + return component_name + + def get_components(self, path, method): + """ + Return components with their properties from the serializer. + """ + + if method.lower() == 'delete': + return {} + + serializer = self.get_serializer(path, method) + if not isinstance(serializer, (dt.Serializer,serializers.Serializer)): + return {} + + component_name = self.get_component_name(serializer) + content = self.map_serializer(serializer) + return {component_name: content} + + def _to_camel_case(self, snake_str): + components = snake_str.split('_') + # We capitalize the first letter of each component except the first one + # with the 'title' method and join them together. + return components[0] + ''.join(x.title() for x in components[1:]) + + def get_operation_id_base(self, path, method, action): + """ + Compute the base part for operation ID from the model, serializer or view name. 
+ """ + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + + if self.operation_id_base is not None: + name = self.operation_id_base + + # Try to deduce the ID from the view's model + elif model is not None: + name = model.__name__ + + # Try with the serializer class name + elif self.get_serializer(path, method) is not None: + name = self.get_serializer(path, method).__class__.__name__ + if name.endswith('Serializer'): + name = name[:-10] + + # Fallback to the view name + else: + name = self.view.__class__.__name__ + if name.endswith('APIView'): + name = name[:-7] + elif name.endswith('View'): + name = name[:-4] + + # Due to camel-casing of classes and `action` being lowercase, apply title in order to find if action truly + # comes at the end of the name + if name.endswith(action.title()): # ListView, UpdateAPIView, ThingDelete ... + name = name[:-len(action)] + + if action == 'list' and not name.endswith('s'): # listThings instead of listThing + name += 's' + + return name + + def get_operation_id(self, path, method): + """ + Compute an operation ID from the view type and get_operation_id_base method. + """ + method_name = getattr(self.view, 'action', method.lower()) + if is_list_view(path, method, self.view): + action = 'list' + elif method_name not in self.method_mapping: + action = self._to_camel_case(method_name) + else: + action = self.method_mapping[method.lower()] + + name = self.get_operation_id_base(path, method, action) + + return action + name + + def get_path_parameters(self, path, method): + """ + Return a list of parameters from templated path variables. + """ + assert uritemplate, '`uritemplate` must be installed for OpenAPI schema support.' + + model = getattr(getattr(self.view, 'queryset', None), 'model', None) + parameters = [] + + for variable in uritemplate.variables(path): + description = '' + if model is not None: # TODO: test this. + # Attempt to infer a field description if possible. 
+ try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_str(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + parameter = { + "name": variable, + "in": "path", + "required": True, + "description": description, + 'schema': { + 'type': 'string', # TODO: integer, pattern, ... + }, + } + parameters.append(parameter) + + return parameters + + def get_filter_parameters(self, path, method): + if not self.allows_filters(path, method): + return [] + parameters = [] + for filter_backend in self.view.filter_backends: + parameters += filter_backend().get_schema_operation_parameters(self.view) + return parameters + + def allows_filters(self, path, method): + """ + Determine whether to include filter Fields in schema. + + Default implementation looks for ModelViewSet or GenericAPIView + actions/methods that cause filtering on the default implementation. 
+ """ + if getattr(self.view, 'filter_backends', None) is None: + return False + if hasattr(self.view, 'action'): + return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + return method.lower() in ["get", "put", "patch", "delete"] + + def get_pagination_parameters(self, path, method): + view = self.view + + if not is_list_view(path, method, view): + return [] + + paginator = self.get_paginator() + if not paginator: + return [] + + return paginator.get_schema_operation_parameters(view) + + def map_choicefield(self, field): + choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates + if all(isinstance(choice, bool) for choice in choices): + type = 'boolean' + elif all(isinstance(choice, int) for choice in choices): + type = 'integer' + elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` + # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21 + type = 'number' + elif all(isinstance(choice, str) for choice in choices): + type = 'string' + else: + type = None + + mapping = { + # The value of `enum` keyword MUST be an array and SHOULD be unique. + # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 + 'enum': choices + } + + # If We figured out `type` then and only then we should set it. It must be a string. + # Ref: https://swagger.io/docs/specification/data-models/data-types/#mixed-type + # It is optional but it can not be null. + # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.21 + if type: + mapping['type'] = type + return mapping + + def map_field(self, field): + + # Nested Serializers, `many` or not. 
+ if isinstance(field, serializers.ListSerializer): + return { + 'type': 'array', + 'items': self.map_serializer(field.child) + } + if isinstance(field, (dt.Serializer,serializers.Serializer)): + data = self.map_serializer(field) + data['type'] = 'object' + return data + + # Related fields. + if isinstance(field, (dt.ManyRelatedField,serializers.ManyRelatedField)): + return { + 'type': 'array', + 'items': self.map_field(field.child_relation) + } + if isinstance(field, (dt.RelatedField,serializers.PrimaryKeyRelatedField)): + model = getattr(field.queryset, 'model', None) + if model is not None: + model_field = model._meta.pk + if isinstance(model_field, models.AutoField): + return {'type': 'integer'} + + # ChoiceFields (single and multiple). + # Q: + # - Is 'type' required? + # - can we determine the TYPE of a choicefield? + if isinstance(field, (dt.MultipleChoiceField,serializers.MultipleChoiceField)): + return { + 'type': 'array', + 'items': self.map_choicefield(field) + } + + if isinstance(field, (dt.ChoiceField,serializers.ChoiceField)): + return self.map_choicefield(field) + + # ListField. + if isinstance(field, (dt.ArrayField,serializers.ListField)): + mapping = { + 'type': 'array', + 'items': {}, + } + if not isinstance(field.child, _UnvalidatedField): + mapping['items'] = self.map_field(field.child) + return mapping + + # DateField and DateTimeField type is string + if isinstance(field, (dt.DateField,serializers.DateField)): + return { + 'type': 'string', + 'format': 'date', + } + + if isinstance(field, (dt.DateTimeField,serializers.DateTimeField)): + return { + 'type': 'string', + 'format': 'date-time', + } + + # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." 
+ # see: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types + # see also: https://swagger.io/docs/specification/data-models/data-types/#string + if isinstance(field, (dt.EmailField,serializers.EmailField)): + return { + 'type': 'string', + 'format': 'email' + } + + if isinstance(field, (dt.URLField,serializers.URLField)): + return { + 'type': 'string', + 'format': 'uri' + } + + if isinstance(field, (dt.UUIDField,serializers.UUIDField)): + return { + 'type': 'string', + 'format': 'uuid' + } + + if isinstance(field, (dt.IPField,serializers.IPAddressField)): + content = { + 'type': 'string', + } + if field.protocol != 'both': + content['format'] = field.protocol + return content + + if isinstance(field, (dt.DecimalField,serializers.DecimalField)): + if getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): + content = { + 'type': 'string', + 'format': 'decimal', + } + else: + content = { + 'type': 'number' + } + + if field.decimal_places: + content['multipleOf'] = float('.' 
+ (field.decimal_places - 1) * '0' + '1') + if field.max_whole_digits: + content['maximum'] = int(field.max_whole_digits * '9') + 1 + content['minimum'] = -content['maximum'] + self._map_min_max(field, content) + return content + + if isinstance(field, (dt.FloatField,serializers.FloatField)): + content = { + 'type': 'number', + } + self._map_min_max(field, content) + return content + + if isinstance(field, (dt.IntField,serializers.IntegerField)): + content = { + 'type': 'integer' + } + self._map_min_max(field, content) + # 2147483647 is max for int32_size, so we use int64 for format + if int(content.get('maximum', 0)) > 2147483647 or int(content.get('minimum', 0)) > 2147483647: + content['format'] = 'int64' + return content + + if isinstance(field, (dt.FileField,serializers.FileField)): + return { + 'type': 'string', + 'format': 'binary' + } + + # Simplest cases, default to 'string' type: + FIELD_CLASS_SCHEMA_TYPE = { + serializers.BooleanField : 'boolean', + dt.BoolField: 'boolean', + serializers.JSONField : 'object', + dt.JSONField: 'object', + serializers.DictField: 'object', + dt.DictField : 'object', + serializers.HStoreField: 'object', + } + return {'type': FIELD_CLASS_SCHEMA_TYPE.get(field.__class__, 'string')} + + def _map_min_max(self, field, content): + if field.max_value: + content['maximum'] = field.max_value + if field.min_value: + content['minimum'] = field.min_value + + def map_serializer(self, serializer): + # Assuming we have a valid serializer instance. 
+ required = [] + properties = {} + + for name,field in serializer.fields.items(): + default = field.default_value if isinstance(field,dt.Field) else field.default + if isinstance(field, serializers.HiddenField): + continue + if field.required: + required.append(name) + schema = self.map_field(field) + if field.read_only: + schema['readOnly'] = True + if field.write_only: + schema['writeOnly'] = True + if field.allow_null: + schema['nullable'] = True + if default is not None and default != empty and default != NO_DEFAULT and not callable(default): + schema['default'] = default + if field.help_text: + schema['description'] = str(field.help_text) + self.map_field_validators(field, schema) + + properties[name] = schema + + result = { + 'type': 'object', + 'properties': properties + } + if required: + result['required'] = required + + return result + + def map_field_validators(self, field, schema): + """ + map field validators + """ + for v in field.validators: + # "Formats such as "email", "uuid", and so on, MAY be used even though undefined by this specification." + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#data-types + if isinstance(v, EmailValidator): + schema['format'] = 'email' + if isinstance(v, URLValidator): + schema['format'] = 'uri' + if isinstance(v, RegexValidator): + # In Python, the token \Z does what \z does in other engines. 
+ # https://stackoverflow.com/questions/53283160 + schema['pattern'] = v.regex.pattern.replace('\\Z', '\\z') + elif isinstance(v, MaxLengthValidator): + attr_name = 'maxLength' + if isinstance(field, (dt.ArrayField,serializers.ListField)): + attr_name = 'maxItems' + schema[attr_name] = v.limit_value + elif isinstance(v, MinLengthValidator): + attr_name = 'minLength' + if isinstance(field, (dt.ArrayField,serializers.ListField)): + attr_name = 'minItems' + schema[attr_name] = v.limit_value + elif isinstance(v, MaxValueValidator): + schema['maximum'] = v.limit_value + elif isinstance(v, MinValueValidator): + schema['minimum'] = v.limit_value + elif isinstance(v, DecimalValidator) and \ + not getattr(field, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING): + if v.decimal_places: + schema['multipleOf'] = float('.' + (v.decimal_places - 1) * '0' + '1') + if v.max_digits: + digits = v.max_digits + if v.decimal_places is not None and v.decimal_places > 0: + digits -= v.decimal_places + schema['maximum'] = int(digits * '9') + 1 + schema['minimum'] = -schema['maximum'] + + def get_paginator(self): + pagination_class = getattr(self.view, 'pagination_class', None) + if pagination_class: + return pagination_class() + return None + + def map_parsers(self, path, method): + return list(map(attrgetter('media_type'), self.view.parser_classes)) + + def map_renderers(self, path, method): + media_types = [] + for renderer in self.view.renderer_classes: + # BrowsableAPIRenderer not relevant to OpenAPI spec + if issubclass(renderer, renderers.BrowsableAPIRenderer): + continue + media_types.append(renderer.media_type) + return media_types + + def get_serializer(self, path, method): + view = self.view + if not hasattr(view, 'get_serializer'): + return None + + try: + return view.get_serializer() + except exceptions.APIException: + warnings.warn('{}.get_serializer() raised an exception during ' + 'schema generation. Serializer fields will not be ' + 'generated for {} {}.' 
+ .format(view.__class__.__name__, method, path)) + return None + + def _get_reference(self, serializer): + return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} + + def get_request_body(self, path, method): + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + self.request_media_types = self.map_parsers(path, method) + + serializer = self.get_serializer(path, method) + if not isinstance(serializer, (dt.Serializer,serializers.Serializer)): + item_schema = {} + else: + item_schema = self._get_reference(serializer) + + return { + 'content': { + ct: {'schema': item_schema} + for ct in self.request_media_types + } + } + + def get_responses(self, path, method): + if method == 'DELETE': + return { + '204': { + 'description': '' + } + } + + self.response_media_types = self.map_renderers(path, method) + + serializer = self.get_serializer(path, method) + + if not isinstance(serializer, (dt.Serializer,serializers.Serializer)): + item_schema = {} + else: + item_schema = self._get_reference(serializer) + + if is_list_view(path, method, self.view): + response_schema = { + 'type': 'array', + 'items': item_schema, + } + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) + else: + response_schema = item_schema + status_code = '201' if method == 'POST' else '200' + return { + status_code: { + 'content': { + ct: {'schema': response_schema} + for ct in self.response_media_types + }, + # description is a mandatory property, + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#responseObject + # TODO: put something meaningful into it + 'description': "" + } + } + + def get_tags(self, path, method): + # If user have specified tags, use them. + if self._tags: + return self._tags + + # First element of a specific path could be valid tag. This is a fallback solution. 
+ # PUT, PATCH, GET(Retrieve), DELETE: /user_profile/{id}/ tags = [user-profile] + # POST, GET(List): /user_profile/ tags = [user-profile] + if path.startswith('/'): + path = path[1:] + + return [path.split('/')[0].replace('_', '-')] + + def _get_path_parameters(self, path, method): + warnings.warn( + "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_path_parameters(path, method) + + def _get_filter_parameters(self, path, method): + warnings.warn( + "Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_filter_parameters(path, method) + + def _get_responses(self, path, method): + warnings.warn( + "Method `_get_responses()` has been renamed to `get_responses()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_responses(path, method) + + def _get_request_body(self, path, method): + warnings.warn( + "Method `_get_request_body()` has been renamed to `get_request_body()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_request_body(path, method) + + def _get_serializer(self, path, method): + warnings.warn( + "Method `_get_serializer()` has been renamed to `get_serializer()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_serializer(path, method) + + def _get_paginator(self): + warnings.warn( + "Method `_get_paginator()` has been renamed to `get_paginator()`. 
" + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_paginator() + + def _map_field_validators(self, field, schema): + warnings.warn( + "Method `_map_field_validators()` has been renamed to `map_field_validators()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_field_validators(field, schema) + + def _map_serializer(self, serializer): + warnings.warn( + "Method `_map_serializer()` has been renamed to `map_serializer()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_serializer(serializer) + + def _map_field(self, field): + warnings.warn( + "Method `_map_field()` has been renamed to `map_field()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_field(field) + + def _map_choicefield(self, field): + warnings.warn( + "Method `_map_choicefield()` has been renamed to `map_choicefield()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.map_choicefield(field) + + def _get_pagination_parameters(self, path, method): + warnings.warn( + "Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. " + "The old name will be removed in DRF v3.14.", + RemovedInDRF314Warning, stacklevel=2 + ) + return self.get_pagination_parameters(path, method) + + def _allows_filters(self, path, method): + warnings.warn( + "Method `_allows_filters()` has been renamed to `allows_filters()`. 
from django.conf import settings
from drf_turbo.exceptions import ParseError
import json
import codecs


try:
    import ujson
except ImportError:
    ujson = None

try:
    import orjson
except ImportError:
    orjson = None

cimport cython

cdef class BaseParser:
    """
    Base class for all parsers.

    Subclasses must set a `media_type` attribute and override `.parse()`.
    """
    media_type = None

    cpdef dict parse(self, stream, media_type=None, parser_context=None):
        """
        Given a stream to read from, return the parsed representation.

        :param stream: File-like object holding the raw request body.
        :param media_type: Optional media type of the stream.
        :param parser_context: Optional dict of extra context (e.g. encoding).
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(".parse() must be overridden.")


@cython.final
cdef class JSONParser(BaseParser):
    """
    Parses JSON-serialized data using the standard-library ``json`` module.
    """
    media_type = 'application/json'

    cpdef inline dict parse(self, stream, media_type=None, parser_context=None):
        """
        Decode the stream with the request encoding and parse it as JSON.

        :raises ParseError: If the payload is not valid JSON.
        """
        parser_context = parser_context or {}
        encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)

        try:
            data = stream.read().decode(encoding)
            # BUGFIX: use the stdlib ``json`` module here.  The previous
            # implementation called ``ujson.loads``, which raises
            # AttributeError whenever the optional ujson package is not
            # installed (it is bound to None above).  ``UJSONParser`` below
            # is the opt-in ujson-backed parser.
            return json.loads(data)
        except ValueError as exc:
            raise ParseError('JSON parse error - %s' % str(exc))


@cython.final
cdef class UJSONParser(BaseParser):
    """
    Parses JSON-serialized data using the optional ``ujson`` parser.
    """

    media_type = "application/json"

    cpdef inline dict parse(self, stream, media_type=None, parser_context=None):
        """
        Decode the stream with the request encoding and parse it with ujson.

        :raises AssertionError: If ujson is not installed.
        :raises ParseError: If the payload is not valid JSON.
        """
        assert ujson is not None, "ujson must be installed to use UJSONParser"
        parser_context = parser_context or {}
        encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)

        try:
            data = stream.read().decode(encoding)
            return ujson.loads(data)
        except ValueError as exc:
            # BUGFIX: the message previously said "ORJSON parse error"
            # (copy/paste from ORJSONParser below).
            raise ParseError('UJSON parse error - %s' % str(exc))


@cython.final
cdef class ORJSONParser(BaseParser):
    """
    Parses JSON-serialized data using the optional ``orjson`` parser.
    """

    media_type = "application/json"

    cpdef inline dict parse(self, stream, media_type=None, parser_context=None):
        """
        Decode the stream with the request encoding and parse it with orjson.

        :raises AssertionError: If orjson is not installed.
        :raises ParseError: If the payload is not valid JSON.
        """
        assert orjson is not None, "orjson must be installed to use ORJSONParser"
        parser_context = parser_context or {}
        encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)

        try:
            data = stream.read().decode(encoding)
            return orjson.loads(data)
        except ValueError as exc:
            raise ParseError('ORJSON parse error - %s' % str(exc))
+ """ + media_type = None + format = None + charset = 'utf-8' + render_style = 'text' + + cpdef bytes render(self,data, accepted_media_type=None,renderer_context=None): + raise NotImplementedError('Renderer class requires .render() to be implemented') + + +@cython.final +cdef class JSONRenderer(BaseRenderer): + """ + Renderer which serializes to JSON. + """ + media_type = 'application/json' + format = 'json' + encoder_class = encoders.JSONEncoder + ensure_ascii = False + compact = True + strict = True + charset = None + + cpdef inline get_indent(self,unicode accepted_media_type, dict renderer_context): + if accepted_media_type: + base_media_type, params = parse_header(accepted_media_type.encode('ascii')) + try: + return zero_as_none(max(min(int(params['indent']), 8), 0)) + except (KeyError, ValueError, TypeError): + pass + return renderer_context.get('indent', None) + + cpdef inline bytes render(self, data, accepted_media_type=None, renderer_context=None): + """ + Render `data` into JSON, returning a bytestring. + """ + cdef basestring ret + if data is None: + return b'' + + renderer_context = renderer_context or {} + indent = self.get_indent(accepted_media_type, renderer_context) + + if indent is None: + separators = SHORT_SEPARATORS if self.compact else LONG_SEPARATORS + else: + separators = INDENT_SEPARATORS + + ret = json.dumps( + data, cls=self.encoder_class, + indent=indent, ensure_ascii=self.ensure_ascii, + allow_nan=not self.strict, separators=separators + ) + + ret = ret.replace('\u2028', '\\u2028').replace('\u2029', '\\u2029') + return ret.encode() + +@cython.final +cdef class UJSONRenderer(BaseRenderer): + """ + Renderer which serializes to JSON. 
+ """ + media_type = 'application/json' + format = 'json' + ensure_ascii = False + escape_forward_slashes= False + encode_html_chars= False + charset = None + + + cpdef inline get_indent(self,unicode accepted_media_type, dict renderer_context): + cdef dict params + if accepted_media_type: + _, params = parse_header(accepted_media_type.encode('ascii')) + try: + return zero_as_none(max(min(int(params['indent']), 8), 0)) + except (KeyError, ValueError, TypeError): + pass + + return renderer_context.get('indent', None) + + cpdef inline bytes render(self, data,accepted_media_type=None,renderer_context=None): + """ + Render `data` into JSON, returning a bytestring. + """ + assert ujson is not None, "ujson must be installed to use UJSONRenderer" + cdef basestring ret + if data is None: + return b'' + + renderer_context = renderer_context or {} + indent = self.get_indent(accepted_media_type, renderer_context) + + + ret = ujson.dumps( + data, + indent=indent or 0, + ensure_ascii=self.ensure_ascii, + encode_html_chars=self.encode_html_chars, + escape_forward_slashes=self.escape_forward_slashes, + ) + ret = ret.replace("\u2028", "\\u2028").replace("\u2029", "\\u2029") + return bytes(ret.encode("utf-8")) + +@cython.final +cdef class ORJSONRenderer(BaseRenderer): + """ + Renderer which serializes to JSON. + """ + + media_type = "application/json" + format = "json" + charset = None + + + + cpdef inline bytes render(self, data,accepted_media_type=None,renderer_context=None): + """ + Render `data` into JSON, returning a bytestring. 
+ """ + assert orjson is not None, "orjson must be installed to use ORJSONParser" + if data is None: + return b'' + + return orjson.dumps( + data, + option = orjson.OPT_SERIALIZE_UUID | \ + orjson.OPT_SERIALIZE_NUMPY | \ + orjson.OPT_SERIALIZE_DATACLASS | \ + orjson.OPT_NON_STR_KEYS, + ) diff --git a/drf_turbo/response.pyx b/drf_turbo/response.pyx new file mode 100755 index 0000000..82c6fc3 --- /dev/null +++ b/drf_turbo/response.pyx @@ -0,0 +1,73 @@ +from django.http import HttpResponse +from typing import Optional,Any +from rest_framework import status +from django.core.serializers.json import DjangoJSONEncoder +import json + +try: + import ujson +except ImportError: + ujson = None + +try: + import orjson +except ImportError: + orjson = None + +class JSONResponse(HttpResponse): + def __init__(self, data : Any,**kwargs) -> None: + kwargs.setdefault('content_type', 'application/json') + data = json.dumps( + data, + cls=DjangoJSONEncoder, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ).encode("utf-8") + super().__init__(content=data, **kwargs) + + +class UJSONResponse(HttpResponse): + def __init__(self, data : Any,**kwargs) -> None: + assert ujson is not None, "ujson must be installed to use UJSONResponse" + kwargs.setdefault('content_type', 'application/json') + data = ujson.dumps(data, ensure_ascii=False).encode("utf-8") + super().__init__(content=data, **kwargs) + + +class ORJSONResponse(HttpResponse): + def __init__(self, data : Any,**kwargs) -> None: + assert orjson is not None, "orjson must be installed to use ORJSONResponse" + kwargs.setdefault('content_type', 'application/json') + data = orjson.dumps(data) + super().__init__(content=data, **kwargs) + + +class SuccessResponse : + + def __new__(cls,data,message:Optional[str]=None,status_code:Optional[int]=None,default:Any=JSONResponse) -> Any : + if not message : + message = 'Success' + if not status_code : + status_code = status.HTTP_200_OK + return default(dict([ + 
('message', message), + ('data', data), + ('error' , False) + ]),status=status_code) + + + +class ErrorResponse : + + def __new__(cls,data,message:Optional[str]=None,status_code:Optional[int]=None,default: Any=JSONResponse) -> Any : + if not message : + message = 'Bad request' + if not status_code : + status_code = status.HTTP_400_BAD_REQUEST + return default(dict([ + ('message', message), + ('data', data), + ('error' , True ) + ]),status=status_code) \ No newline at end of file diff --git a/drf_turbo/serializer.pxd b/drf_turbo/serializer.pxd new file mode 100755 index 0000000..4fa6586 --- /dev/null +++ b/drf_turbo/serializer.pxd @@ -0,0 +1,31 @@ +from drf_turbo.fields cimport Field +cimport cython_metaclass + + +cdef class BaseSerializer(Field) : + cdef: + object instance + public bint many + object data + public dict context + readonly object only + readonly object exclude + public bint partial + + cpdef bint is_valid(self,bint raise_exception=*) except -1 + cpdef dict get_initial_data(self) + + + + +cdef class Serializer(BaseSerializer): + cdef inline dict _parse_nested_fields(self,object fields) + cdef inline void _select_nested_fields(self,Serializer serializer,object fields,basestring action,bint is_nested=*) + cdef inline object _fields_to_include(self,Serializer serializer,object fields) + cdef inline object _fields_to_exclude(self,Serializer serializer,object fields,bint is_nested) + cdef dict _serialize(self,object instance,dict fields) + cpdef serialize(self,instance,dict context) + cdef dict _deserialize(self, object data,dict fields) + cpdef deserialize(self,object data,dict context) + cpdef run_validation(self,object data,dict context) + cpdef validate(self,object data) diff --git a/drf_turbo/serializer.pyx b/drf_turbo/serializer.pyx new file mode 100755 index 0000000..94356f1 --- /dev/null +++ b/drf_turbo/serializer.pyx @@ -0,0 +1,509 @@ +# cython: language_level=3 +# cython: embedsignature=True +# cython: wraparound=False +# cython: nonecheck=False 
+# cython: boundscheck=False + +from django.utils.functional import cached_property +from drf_turbo.fields cimport Field,RelatedField,SkipField,NO_DEFAULT +from drf_turbo.exceptions import * +from drf_turbo.utils import * +from django.core.exceptions import ValidationError as DjangoValidationError +from collections.abc import Mapping +import traceback +cimport cython + +cdef class BaseSerializer(Field): + """ + Base class for all serializers. + + :param instance: The instance to be serialized. + :param many: Serialize single object or a list of objects. + :param data: The data to be deserialized. + :param context: The context dictionary passed to the serializer. + :param partial: If set to True, partial update will be allowed. + :param only: A list of fields to be included in the serialized data. + :param exclude: A list of fields to be excluded from the serialized data. + :param kwargs: Extra keyword arguments. + + """ + def __init__( + self, + object instance = None, + bint many = False, + object data = None, + dict context = None, + object only = None, + object exclude = None, + bint partial = False, + **kwargs + ): + if only is not None and exclude is not None : + raise OnlyAndExcludeError('You should use either "only" or "exclude"') + if only is not None and not is_collection(only): + raise StringNotCollectionError('"only" should be a list of strings') + if exclude is not None and not is_collection(exclude): + raise StringNotCollectionError('"exclude" should be a list of strings') + super().__init__(**kwargs) + self._instance =instance + self._data = data + self.many = many + self._initial_data = None + self._initial_instance = None + self.context = context + self.only = only + self.exclude = exclude + self.partial = partial + + + cpdef bint is_valid(self,bint raise_exception=False) except -1: + """ + Whether the data is valid. + + :param raise_exception: Whether to raise an exception if the data is invalid. 
+ + """ + assert hasattr(self, '_data'), ( + 'Cannot call `.is_valid()` as no `data=` keyword argument was ' + 'passed when instantiating the serializer instance.' + ) + if not hasattr(self, '_validated_data'): + try: + self._validated_data =self.run_validation(self._data,self.context) + except ValidationError as exc: + self._validated_data = {} + self._errors = exc.detail + else: + self._errors = {} + + if self._errors and raise_exception: + raise ValidationError(self.errors) + return not bool(self._errors) + + + def save(self,**kwargs): + """ + Create or update a model instance. + + :param kwargs: Extra keyword arguments. + """ + assert not self._initial_data, ( + "You cannot call `.save()` after accessing `serializer.data`." + "If you need to access data before committing to the database then " + "inspect 'serializer.validated_data' instead. " + ) + + validated_data = {**self.validated_data, **kwargs} + if self._instance is not None: + self._instance = self.update(self._instance, validated_data) + else: + self._instance = self.create(validated_data) + + return self._instance + + + @property + def errors(self): + """ + Return the dictionary of errors raised during validation. + """ + if not hasattr(self, '_errors'): + msg = 'You must call `.is_valid()` before accessing `.errors`.' + raise AssertionError(msg) + return self._errors + + cpdef dict get_initial_data(self): + """ + Return the initial data for the fields. + """ + if self._data is not None: + + if not isinstance(self._data, Mapping): + return dict() + return dict([ + (name, self._data.get(name,NO_DEFAULT)) + for name, field in self.fields.items() + if (self._data.get(name,NO_DEFAULT) is not NO_DEFAULT) and + not field.read_only + ]) + + return dict([ + (name, field.get_initial()) + for name,field in self.fields.items() + if not field.read_only + ]) + + @property + def data(self): + """ + Return the serialized data on the serializer. 
+ """ + if not self._initial_data : + if self._instance is not None and not getattr(self, '_errors', None): + self._initial_data = self.serialize(self._instance,self.context) + elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None): + self._initial_data = self.serialize(self.validated_data,self.context) + + else: + self._initial_data = self.get_initial_data() + + return self._initial_data + + @property + def instance(self): + """ + Return the model instance that is being serialized. + """ + return self._instance + + + @property + def validated_data(self): + """ + Return the validated data on the serializer. + """ + if not hasattr(self, '_validated_data'): + msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + raise AssertionError(msg) + return self._validated_data + +cdef class Serializer(BaseSerializer): + + def __getmetaclass__(_): + from drf_turbo.meta import SerializerMetaclass + return SerializerMetaclass + + def get_fields(self): + """ + Return the dict of field names -> field instances that should be added to the serializer. + """ + return deepcopy(self._fields) + + + def fields(self): + """ + This is a shortcut for accessing the dict on the *fields* attribute. + """ + cdef str key + cdef Field value + cdef dict fields = {} + for key, value in self.get_fields().items(): + fields[key] = value + fields[key].bind(key,self) + return fields + + from forbiddenfruit import curse as _curse + _curse(Serializer,'fields',cached_property(fields)) + Serializer.fields.__set_name__(Serializer, 'fields') + + + @property + def _writable_fields(self): + """ + @property _writable_fields + Return a list of all writable fields. + """ + cdef str k + cdef Field v + return {k: v for k, v in self.fields.items() if not v.read_only} + + + @property + def _readable_fields(self): + """ + Return a list of all readable fields. 
+ """ + cdef str k + cdef Field v + return {k: v for k, v in self.fields.items() if not v.write_only} + + @property + def _only_fields(self): + """ + Return a list of all fields that have been specified in the `only` option. + """ + only = self.only or self.context.get('request').GET.get('only').split(',') + is_nested = any('__' in field for field in only) + if is_nested : + fields = self._parse_nested_fields(only) + self._select_nested_fields(self,fields,action='include') + else: + self._fields_to_include(self,only) + return self.fields + + @property + def _exclude_fields(self): + """ + Return a list of all fields that have been specified in the `exclude` option. + """ + exclude = self.exclude or self.context.get('request').GET.get('exclude').split(',') + is_nested = any('__' in field for field in exclude) + if is_nested: + fields = self._parse_nested_fields(exclude) + self._select_nested_fields(self,fields,action='exclude',is_nested=True) + else: + self._fields_to_exclude(self,exclude,is_nested=False) + + return self.fields + + cdef inline dict _parse_nested_fields(self,object fields): + """ + Parse nested fields + + :param fields: A list of fields to parse. 
+ """ + cdef dict field_object = {"fields": []} + cdef str f + for f in fields: + obj = field_object + nested_fields = f.split("__") + for v in nested_fields: + if v not in obj["fields"]: + obj["fields"].append(v) + if nested_fields.index(v) < len(nested_fields) - 1: + obj[v] = obj.get(v, {"fields": []}) + obj = obj[v] + return field_object + + cdef inline void _select_nested_fields(self,Serializer serializer,object fields,basestring action,bint is_nested=False): + for k in fields: + if k == "fields": + if action == 'include' : + self._fields_to_include(serializer, fields[k]) + elif action == 'exclude': + self._fields_to_exclude(serializer,fields[k],is_nested) + + else: + self._select_nested_fields(serializer.fields[k], fields[k],action=action,is_nested=is_nested) + + cdef inline object _fields_to_include(self,Serializer serializer,object fields): + """ + Include fields. + + :param serializer: The serializer to include fields on. + :param fields: A list of fields to include. + """ + allowed = set(fields) + existing = set(serializer.fields.keys()) + for field_name in existing - allowed: + if field_name in serializer.fields : + serializer.fields.pop(field_name) + return serializer.fields + + cdef inline object _fields_to_exclude(self,Serializer serializer,object fields,bint is_nested): + """ + Exclude fields. + + :param serializer: The serializer to exclude fields on. + :param fields: A list of fields to exclude. + :param is_nested: Whether the fields are nested. 
+ """ + excluded = set(fields) + existing = set(serializer.fields.keys()) + for field_name in excluded : + if is_nested : + if field_name in existing : + if field_name in self.fields.keys() : + if not issubclass(self.fields[field_name].__class__,Serializer): + serializer.fields.pop(field_name) + else: + serializer.fields.pop(field_name) + else: + if field_name in serializer.fields: + serializer.fields.pop(field_name) + return serializer.fields + + + cdef dict _serialize(self,object instance,dict fields): + + cdef str name + cdef Field field + cdef dict ret = {} + cdef bint is_dict + for name,field in fields.items(): + attr = field.attr if field.attr and not '.' in field.attr else name + if field.is_method_field: + result = field.method_getter(attr,self.__class__)(self,instance) + else: + try: + if isinstance(field,RelatedField): + result = field.get_attribute(instance,attr + '_id') + else: + result = field.get_attribute(instance) + if hasattr(result,'all'): + result = result.all() + + except SkipField : + continue + + if result is not None: + if field.call: + result = result() + result = field.serialize(result,self.context) + ret[attr] = result + + return ret + + + cpdef serialize(self,object instance,dict context): + """ + Serialize a model instance. + + :param instance: Model instance to serialize. + :param context: Context data. 
+ """ + try: + only = self.context.get('request').GET.get('only') + except : + only = None + try: + exclude = self.context.get('request').GET.get('exclude') + except : + exclude = None + if self.only or only is not None : + fields = self._only_fields + elif self.exclude or exclude is not None: + fields = self._exclude_fields + fields = self._readable_fields + if self.many : + return [self._serialize(o,fields) for o in instance] + return self._serialize(instance,fields) + + cdef dict _deserialize(self, object data,dict fields): + if not isinstance(data, Mapping): + raise ValidationError( + 'Invalid data type: %s' % type(data).__name__ + ) + cdef dict ret = {} + cdef dict errors = {} + cdef str name + for name,field in fields.items(): + attr = field.attr if field.attr and not '.' in field.attr else name + validate_method = getattr(self, 'validate_' + attr, None) + value = data.get(name,NO_DEFAULT) + try: + validated_value = field.run_validation(value,self.context) + if validate_method is not None: + validated_value = validate_method(validated_value) + + except ValidationError as exc: + errors[name] = exc.detail + except DjangoValidationError as exc: + errors[name] = get_error_detail(exc) + except SkipField: + continue + else: + ret[attr] = validated_value + + if errors: + raise ValidationError(errors) + return ret + + cpdef deserialize(self,object data,dict context): + """ + Given a dictionary-like structure, build a dictionary of deserialized + fields and return a model instance. + + :param data: The data to deserialize. + :param context: The context for the request. + """ + fields = self._writable_fields + if self.many : + return [self._deserialize(o,fields) for o in data] + return self._deserialize(data,fields) + + cpdef run_validation(self,object data,dict context): + """ + Validate an entire bundle of data. + + :param data: The data to validate. + :param context: The context for the request. 
+ """ + value = self.validate(self.deserialize(data,context)) + return value + + cpdef validate(self,object data): + """ + Validate a dictionary of deserialized field values. + + :param data: A dictionary of deserialized field values. + """ + return data + + + + +cdef class ModelSerializer(Serializer): + + def __getmetaclass__(_): + from drf_turbo.meta import ModelSerializerMetaclass + return ModelSerializerMetaclass + + cpdef create(self, validated_data): + """ + Create a model instance. + + :param validated_data: A dictionary of validated data. + """ + model = self.Meta.model + opts = model._meta.concrete_model._meta + many_to_many_fields = [field.name for field in opts.many_to_many if field.serialize] + m2m_fields = {} + data = validated_data.copy() + for attr, value in data.items(): + if attr in many_to_many_fields : + m2m_fields[attr] = validated_data.pop(attr) + + + try: + instance = model._default_manager.create(**validated_data) + except TypeError: + tb = traceback.format_exc() + msg = ( + 'Got a `TypeError` when calling `%s.%s.create()`. ' + 'This may be because you have a writable field on the ' + 'serializer class that is not a valid argument to ' + '`%s.%s.create()`. You may need to make the field ' + 'read-only, or override the %s.create() method to handle ' + 'this correctly.\nOriginal exception was:\n %s' % + ( + model.__name__, + model._default_manager.name, + model.__name__, + model._default_manager.name, + self.__class__.__name__, + tb + ) + ) + raise TypeError(msg) + # Save many-to-many relationships after the instance is created. + if m2m_fields: + for field_name, value in m2m_fields.items(): + field = getattr(instance, field_name) + field.set(value) + + return instance + + + cpdef update(self, instance, validated_data): + """ + Update a model instance. + + :param instance: Model instance to update. + :param validated_data: A dictionary of deserialized data. 
+ """ + opts = instance._meta.concrete_model._meta + many_to_many_fields =[field.name for field in opts.many_to_many if field.serialize] + m2m_fields = [] + for attr, value in validated_data.items(): + if attr in many_to_many_fields : + m2m_fields.append((attr, value)) + else: + setattr(instance, attr, value) + + instance.save() + for attr, value in m2m_fields: + field = getattr(instance, attr) + field.set(value) + + return instance diff --git a/drf_turbo/templates/docs.html b/drf_turbo/templates/docs.html new file mode 100755 index 0000000..e48f666 --- /dev/null +++ b/drf_turbo/templates/docs.html @@ -0,0 +1,38 @@ + + + + + Swagger + + + + + + +
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True

import collections.abc
from django.utils.encoding import force_str
from django.core.exceptions import ObjectDoesNotExist
import operator


cpdef bint is_iterable_and_not_string(arg):
    """Return True if ``arg`` is iterable but is not a string."""
    # BUGFIX: ``collections.Iterable`` was a deprecated alias removed in
    # Python 3.10; the ABC lives in ``collections.abc``.
    return (
        isinstance(arg, collections.abc.Iterable)
        and not isinstance(arg, str)
    )

cpdef bint is_collection(obj):
    """Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
    return is_iterable_and_not_string(obj) and not isinstance(obj, dict)


cpdef get_error_detail(exc_info):
    """
    Translate a django ValidationError into readable errors.

    Returns a list of messages for a non-field error, or a dict mapping
    field name -> list of messages when the error carries an ``error_dict``.
    """
    cdef dict error_dict
    cdef list errors
    cdef str k

    try:
        error_dict = exc_info.error_dict
    except AttributeError:
        # No per-field dict: flatten the error list, interpolating any
        # ``%``-style params each message declares.
        return [
            (error.message % error.params) if error.params else error.message
            for error in exc_info.error_list]
    return {
        k: [
            (error.message % error.params) if error.params else error.message
            for error in errors
        ] for k, errors in error_dict.items()
    }


cpdef get_execption_detail(exception):
    """
    Recursively force an exception payload (scalar, list or dict) to text.

    NOTE: the misspelled name ("execption") is kept for backward
    compatibility — it is part of the public API.
    """
    if isinstance(exception, (list, tuple)):
        return [get_execption_detail(item) for item in exception]

    elif isinstance(exception, dict):
        return {key: get_execption_detail(value) for key, value in exception.items()}

    return force_str(exception)


cpdef get_attribute(instance, attrs):
    """
    Walk ``attrs`` (a sequence of attribute/key names) down from ``instance``.

    Dict instances are indexed; other objects use ``getattr``.  A missing
    related object (``ObjectDoesNotExist``) short-circuits to None.  If the
    resolved value is callable it is invoked, and errors raised inside the
    call are re-raised as ``ValueError`` so they are not mistaken for a
    missing attribute.
    """
    for attr in attrs:
        try:
            if isinstance(instance, dict):
                instance = instance[attr]
            else:
                instance = getattr(instance, attr)
        except ObjectDoesNotExist:
            return None

        if callable(instance):
            try:
                instance = instance()
            except (AttributeError, KeyError) as exc:
                raise ValueError(
                    "Unable to resolve attribute '%s' on %s: %s" % (
                        attr, instance, exc
                    )
                )
    return instance
#! /usr/bin/env python3
"""Test runner: translate friendly CLI shorthands into pytest arguments."""
import sys

import pytest


def split_class_and_function(string):
    """Convert 'TestCase.test_fn' into the pytest -k expression 'TestCase and test_fn'."""
    class_string, function_string = string.split('.', 1)
    return "%s and %s" % (class_string, function_string)


def is_function(string):
    """Return True if it looks like a test function is included in the string."""
    return string.startswith('test_') or '.test_' in string


def is_class(string):
    """Return True if the first character is uppercase - assume it's a class name."""
    # BUGFIX: guard against an empty argument; the original indexed
    # string[0] unconditionally and raised IndexError on ''.
    return bool(string) and string[0] == string[0].upper()


if __name__ == "__main__":
    if len(sys.argv) > 1:
        pytest_args = sys.argv[1:]
        first_arg = pytest_args[0]

        # '--coverage' is our own flag: strip it and inject pytest-cov args.
        try:
            pytest_args.remove('--coverage')
        except ValueError:
            pass
        else:
            pytest_args = [
                '--cov', '.',
                '--cov-report', 'xml',
            ] + pytest_args

        if first_arg.startswith('-'):
            # `runtests.py [flags]`
            pytest_args = ['tests'] + pytest_args
        elif is_class(first_arg) and is_function(first_arg):
            # `runtests.py TestCase.test_function [flags]`
            expression = split_class_and_function(first_arg)
            pytest_args = ['tests', '-k', expression] + pytest_args[1:]
        elif is_class(first_arg) or is_function(first_arg):
            # `runtests.py TestCase [flags]`
            # `runtests.py test_function [flags]`
            pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
    else:
        pytest_args = []

    sys.exit(pytest.main(pytest_args))
author_email='mngback@gmail.com', + python_requires='>=3.6', + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + description="An alternative serializer implementation for REST framework written in cython built for speed.", + install_requires=requirements, + license="MIT license", + long_description=readme + '\n\n' + history, + include_package_data=True, + keywords='drf_turbo', + name='drf-turbo', + packages=find_packages(include=['drf_turbo', 'drf_turbo.*']), + test_suite='tests', + url='https://github.com/Mng-dev-ai/drf-turbo', + version='0.1.1', + zip_safe=False, + ext_modules=cythonize(["drf_turbo/*.pyx"]), + +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100755 index 0000000..f0922bb --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Unit test package for drf_turbo.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100755 index 0000000..6acd070 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,60 @@ +import os +import sys + +import django +from django.core import management + + +def pytest_addoption(parser): + parser.addoption('--no-pkgroot', action='store_true', default=False, + help='Remove package root directory from sys.path, ensuring that ' + 'rest_framework is imported from the installed site-packages. ' + 'Used for testing the distribution.') + parser.addoption('--staticfiles', action='store_true', default=False, + help='Run tests with static files collection, using manifest ' + 'staticfiles storage. Used for testing the distribution.') + + +def pytest_configure(config): + from django.conf import settings + + # USE_L10N is deprecated, and will be removed in Django 5.0. 
def pytest_configure(config):
    """Configure a throwaway in-memory Django settings module for the test run."""
    from django.conf import settings

    # USE_L10N is deprecated, and will be removed in Django 5.0.
    use_l10n = {"USE_L10N": True} if django.VERSION < (4, 0) else {}

    settings.configure(
        DEBUG_PROPAGATE_EXCEPTIONS=True,
        DATABASES={
            'default': {
                'ENGINE': 'django.db.backends.sqlite3',
                'NAME': ':memory:',
            },
            'secondary': {
                'ENGINE': 'django.db.backends.sqlite3',
                'NAME': ':memory:',
            },
        },
        SITE_ID=1,
        SECRET_KEY='not very secret in tests',
        USE_I18N=True,
        MIDDLEWARE=(
            'django.middleware.common.CommonMiddleware',
            'django.contrib.sessions.middleware.SessionMiddleware',
            'django.contrib.auth.middleware.AuthenticationMiddleware',
            'django.contrib.messages.middleware.MessageMiddleware',
        ),
        INSTALLED_APPS=(
            'django.contrib.admin',
            'django.contrib.auth',
            'django.contrib.contenttypes',
            'django.contrib.sessions',
            'django.contrib.sites',
            'django.contrib.staticfiles',
            'rest_framework',
            'rest_framework.authtoken',
            'tests',
        ),
        # MD5 hashing keeps user-creating tests fast; never use in production.
        PASSWORD_HASHERS=(
            'django.contrib.auth.hashers.MD5PasswordHasher',
        ),
        **use_l10n,
    )
    django.setup()
+ """ + for input_value, expected_output in get_items(self.valid_inputs): + assert self.field.run_validation(input_value,self.context) == expected_output, \ + 'input value: {}'.format(repr(input_value)) + + def test_invalid_inputs(self): + """ + Ensure that invalid values raise the expected validation error. + """ + for input_value, expected_failure in get_items(self.invalid_inputs): + with pytest.raises(ValidationError) as exc_info: + self.field.run_validation(input_value,self.context) + assert exc_info.value.detail == expected_failure, \ + 'input value: {}'.format(repr(input_value)) + + def test_outputs(self): + for output_value, expected_output in get_items(self.outputs): + assert self.field.serialize(output_value,self.context) == expected_output, \ + 'output value: {}'.format(repr(output_value)) + +class TestStrField(FieldValues): + """ + Valid and invalid values for `CharField`. + """ + valid_inputs = { + 1: '1', + 'abc': 'abc' + } + invalid_inputs = { + (): ['Not a valid string.'], + True: ['Not a valid string.'], + '': ['May not be blank.'] + } + outputs = { + 1: '1', + 'abc': 'abc' + } + field = dt.StrField() + + def test_trim_whitespace_default(self): + field =dt.StrField() + assert field.deserialize(' abc ',self.context) == 'abc' + + def test_trim_whitespace_disabled(self): + field = dt.StrField(trim_whitespace=False) + assert field.deserialize(' abc ',self.context) == ' abc ' + + def test_disallow_blank_with_trim_whitespace(self): + field = dt.StrField(allow_blank=False, trim_whitespace=True) + + with pytest.raises(ValidationError) as exc_info: + field.run_validation(' ',self.context) + assert exc_info.value.detail == ['May not be blank.'] + + def test_null_bytes(self): + field = dt.StrField() + + for value in ('\0', 'foo\0', '\0foo', 'foo\0foo'): + with pytest.raises(ValidationError) as exc_info: + field.run_validation(value,self.context) + assert exc_info.value.detail == [ + 'Null characters are not allowed.' 
+ ] + +class TestEmailField(FieldValues): + """ + Valid and invalid values for `EmailField`. + """ + valid_inputs = { + 'example@example.com': 'example@example.com', + } + invalid_inputs = { + 'examplecom': ['Enter a valid email address.'] + } + outputs = {} + field = dt.EmailField() + +class TestRegexField(FieldValues): + """ + Valid and invalid values for `RegexField`. + """ + valid_inputs = { + 'a9': 'a9', + } + invalid_inputs = { + 'A9': ["This value does not match the required pattern."] + } + outputs = {} + field = dt.RegexField(regex='[a-z][0-9]') + + +class TestURLField(FieldValues): + """ + Valid and invalid values for `URLField`. + """ + valid_inputs = { + 'http://example.com': 'http://example.com', + } + invalid_inputs = { + 'example.com': ['Enter a valid URL.'] + } + outputs = {} + field = dt.URLField() + + +class TestUUIDField(FieldValues): + """ + Valid and invalid values for `UUIDField`. + """ + valid_inputs = { + '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'), + } + invalid_inputs = { + '825d7aeb-05a9-45b5-a5b7': ["Not a valid UUID."], + } + outputs = { + uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda' + } + field = dt.UUIDField() + + +class TestURLField(FieldValues): + """ + Valid and invalid values for `URLField`. + """ + valid_inputs = { + 'http://example.com': 'http://example.com', + } + invalid_inputs = { + 'example.com': ['Enter a valid URL.'] + } + outputs = {} + field = dt.URLField() + + +class TestSlugField(FieldValues): + """ + Valid and invalid values for `SlugField`. 
+ """ + valid_inputs = { + 'slug-99': 'slug-99', + } + invalid_inputs = { + 'slug 99': ['Not a valid slug.'] + } + outputs = {} + field = dt.SlugField() + + def test_allow_unicode_true(self): + field = dt.SlugField(allow_unicode=True) + + validation_error = False + try: + field.run_validation('slug-99-\u0420',self.context) + except dt.ValidationError: + validation_error = True + + assert not validation_error + +class TestIntField(FieldValues): + """ + Valid and invalid values for `IntegerField`. + """ + valid_inputs = { + '1': 1, + 1: 1, + } + invalid_inputs = { + 0.5: ['A valid integer is required.'], + 'abc': ['A valid integer is required.'], + } + outputs = { + '1': 1, + '0': 0, + } + field = dt.IntField() + +class TestMinMaxIntField(FieldValues): + """ + Valid and invalid values for `IntegerField` with min and max limits. + """ + valid_inputs = { + '1': 1, + 3: 3, + } + invalid_inputs = { + 0: ['Ensure this value is greater than or equal to 1.'], + '4': ['Ensure this value is less than or equal to 3.'], + } + outputs = {} + field = dt.IntField(min_value=1, max_value=3) + + + +class TestFloatField(FieldValues): + """ + Valid and invalid values for `FloatField`. + """ + valid_inputs = { + '1': 1.0, + '0': 0.0, + } + invalid_inputs = { + 'abc': ["A valid number is required."] + } + outputs = { + '1': 1.0, + '0': 0.0, + } + field = dt.FloatField() + + +class TestMinMaxFloatField(FieldValues): + """ + Valid and invalid values for `FloatField` with min and max limits. + """ + valid_inputs = { + '1': 1, + 3: 3, + 1.0: 1.0, + } + invalid_inputs = { + 0.9: ['Ensure this value is greater than or equal to 1.'], + '3.1': ['Ensure this value is less than or equal to 3.'], + } + outputs = {} + field = dt.FloatField(min_value=1, max_value=3) + + + +class TestDecimalField(FieldValues): + """ + Valid and invalid values for `DecimalField`. 
+ """ + valid_inputs = { + '12.3': Decimal('12.3'), + '0.1': Decimal('0.1'), + 10: Decimal('10'), + + } + invalid_inputs = ( + ('', ["A valid number is required."]), + (' ', ["A valid number is required."]), + ('abc', ["A valid number is required."]), + (Decimal('Nan'), ["A valid number is required."]), + ) + outputs = { + '1': '1.0', + '0': '0.0', + 1: '1.0', + 0: '0.0', + Decimal('1.09'): '1.1', + Decimal('0.04'): '0.0' + } + field = dt.DecimalField(max_digits=3, decimal_places=1) + + +class TestAllowNullDecimalField(FieldValues): + valid_inputs = { + '': None, + ' ': None, + } + invalid_inputs = {} + outputs = { + None: '', + } + field = dt.DecimalField(max_digits=3, decimal_places=1, allow_null=True) + +class TestAllowNullNoStringCoercionDecimalField(FieldValues): + valid_inputs = { + '': None, + ' ': None, + } + invalid_inputs = {} + outputs = { + None: None, + } + field = dt.DecimalField(max_digits=3, decimal_places=1, allow_null=True, coerce_to_string=False) + + +class TestMinMaxDecimalField(FieldValues): + """ + Valid and invalid values for `DecimalField` with min and max limits. 
+ """ + valid_inputs = { + '10.0': Decimal('10.0'), + '20.0': Decimal('20.0'), + } + invalid_inputs = { + '9.9': ['Ensure this value is greater than or equal to 10.'], + '20.1': ['Ensure this value is less than or equal to 20.'], + } + outputs = {} + field = dt.DecimalField( + max_digits=3, decimal_places=1, + min_value=10, max_value=20 + ) + + +class TestNoDecimalPlaces(FieldValues): + valid_inputs = { + '0.12345': Decimal('0.12345'), + } + invalid_inputs = { + '0.1234567': ['Ensure that there are no more than 6 digits in total.'] + } + outputs = { + '1.2345': '1.2345', + '0': '0', + '1.1': '1.1', + } + field = dt.DecimalField(max_digits=6, decimal_places=None) + + +class TestNoMaxDigitsDecimalField(FieldValues): + field = dt.DecimalField( + max_value=100, min_value=0, + decimal_places=2, max_digits=None + ) + valid_inputs = { + '10': Decimal('10.00') + } + invalid_inputs = {} + outputs = {} + +class TestDateField(FieldValues): + """ + Valid and invalid values for `DateField`. + """ + valid_inputs = { + '2001-01-01': datetime.date(2001, 1, 1), + datetime.date(2001, 1, 1): datetime.date(2001, 1, 1), + } + invalid_inputs = { + 'abc': ['Not a valid date.'], + '2001-99-99': ['Not a valid date.'], + '2001': ['Not a valid date.'], + datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'], + } + outputs = { + datetime.date(2001, 1, 1): '2001-01-01', + '2001-01-01': '2001-01-01', + str('2016-01-10'): '2016-01-10', + None: None, + '': None, + } + field = dt.DateField() + +class TestDateTimeField(FieldValues): + """ + Valid and invalid values for `DateTimeField`. 
+ """ + valid_inputs = { + '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc), + } + invalid_inputs = { + 'abc': ['Not a valid datetime.'], + '2001-99-99T99:00': ['Not a valid datetime.'], + '2018-08-16 22:00-24:00': ['Not a valid datetime.'], + datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'], + '9999-12-31T21:59:59.99990-03:00': ['Datetime value out of range.'], + } + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00Z', + datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): '2001-01-01T13:00:00Z', + '2001-01-01T00:00:00': '2001-01-01T00:00:00', + str('2016-01-10T00:00:00'): '2016-01-10T00:00:00', + None: None, + '': None, + } + field = dt.DateTimeField(default_timezone=utc) + +class TestCustomInputFormatDateTimeField(FieldValues): + """ + Valid and invalid values for `DateTimeField` with a custom input format. + """ + valid_inputs = { + '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=utc), + } + invalid_inputs = { + '2001-01-01T20:50': ['Not a valid datetime.'] + } + outputs = {} + field = dt.DateTimeField(default_timezone=utc, input_formats=['%I:%M%p, %d %b %Y']) + + +class TestCustomOutputFormatDateTimeField(FieldValues): + """ + Values for `DateTimeField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): '01:00PM, 01 Jan 2001', + } + field = dt.DateTimeField(format='%I:%M%p, %d %b %Y') + + +class TestNoOutputFormatDateTimeField(FieldValues): + """ + Values for `DateTimeField` with no output format. 
+ """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00), + } + field = dt.DateTimeField(format=None) + + +class TestCustomOutputFormatDateField(FieldValues): + """ + Values for `DateField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.date(2001, 1, 1): '01 Jan 2001' + } + field = dt.DateField(format='%d %b %Y') + + +class TestTimeField(FieldValues): + """ + Valid and invalid values for `TimeField`. + """ + valid_inputs = { + '13:00': datetime.time(13, 00), + datetime.time(13, 00): datetime.time(13, 00), + } + invalid_inputs = { + 'abc': ['Not a valid time.'], + '99:99': ['Not a valid time.'], + } + outputs = { + datetime.time(13, 0): '13:00:00', + datetime.time(0, 0): '00:00:00', + '00:00:00': '00:00:00', + None: None, + '': None, + } + field = dt.TimeField() + + +class TestCustomInputFormatTimeField(FieldValues): + """ + Valid and invalid values for `TimeField` with a custom input format. + """ + valid_inputs = { + '1:00pm': datetime.time(13, 00), + } + invalid_inputs = { + '13:00': ['Not a valid time.'], + } + outputs = {} + field = dt.TimeField(input_formats=['%I:%M%p']) + + +class TestCustomOutputFormatTimeField(FieldValues): + """ + Values for `TimeField` with a custom output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.time(13, 00): '01:00PM' + } + field = dt.TimeField(format='%I:%M%p') + + +class TestNoOutputFormatTimeField(FieldValues): + """ + Values for `TimeField` with a no output format. + """ + valid_inputs = {} + invalid_inputs = {} + outputs = { + datetime.time(13, 00): datetime.time(13, 00) + } + field = dt.TimeField(format=None) + +class TestChoiceField(FieldValues): + """ + Valid and invalid values for `ChoiceField`. 
+ """ + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'amazing': ['"amazing" is not a valid choice.'] + } + outputs = { + 'good': {'value' : 'good', 'display' : 'Good quality'}, + '': '', + 'amazing':{'display': 'amazing', 'value': 'amazing'} + } + field = dt.ChoiceField( + choices=[ + ('poor', 'Poor quality'), + ('medium', 'Medium quality'), + ('good', 'Good quality'), + ] + ) + + def test_allow_blank(self): + """ + If `allow_blank=True` then '' is a valid input. + """ + field = dt.ChoiceField( + allow_blank=True, + choices=[ + ('poor', 'Poor quality'), + ('medium', 'Medium quality'), + ('good', 'Good quality'), + ] + ) + output = field.run_validation('',self.context) + assert output == '' + + + +class TestChoiceFieldWithType(FieldValues): + """ + Valid and invalid values for a `Choice` field that uses an integer type, + instead of a char type. + """ + valid_inputs = { + '1': 1, + 3: 3, + } + invalid_inputs = { + 5: ['"5" is not a valid choice.'], + 'abc': ['"abc" is not a valid choice.'] + } + outputs = { + '1': {'display': 'Poor quality', 'value': 1}, + 1: {'display': 'Poor quality', 'value': 1}, + } + field = dt.ChoiceField( + choices=[ + (1, 'Poor quality'), + (2, 'Medium quality'), + (3, 'Good quality'), + ] + ) + +class TestChoiceFieldWithListChoices(FieldValues): + """ + Valid and invalid values for a `Choice` field that uses a flat list for the + choices, rather than a list of pairs of (`value`, `description`). + """ + valid_inputs = { + 'poor': 'poor', + 'medium': 'medium', + 'good': 'good', + } + invalid_inputs = { + 'awful': ['"awful" is not a valid choice.'] + } + outputs = { + 'good': {'display':'good','value' : 'good'} + } + field = dt.ChoiceField(choices=('poor', 'medium', 'good')) + +class TestMultipleChoiceField(FieldValues): + """ + Valid and invalid values for `MultipleChoiceField`. 
+ """ + valid_inputs = { + (): set(), + ('aircon',): {'aircon'}, + ('aircon', 'manual'): {'aircon', 'manual'}, + } + invalid_inputs = { + 'abc': ['Expected a list of items but got type "str".'], + ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.'] + } + outputs = [ + (['aircon', 'manual', 'incorrect'], {'aircon', 'manual', 'incorrect'}) + ] + field = dt.MultipleChoiceField( + choices=[ + ('aircon', 'AirCon'), + ('manual', 'Manual drive'), + ('diesel', 'Diesel'), + ] + ) + +class TestEmptyMultipleChoiceField(FieldValues): + """ + Invalid values for `MultipleChoiceField(allow_empty=False)`. + """ + valid_inputs = { + } + invalid_inputs = ( + ([], ['This selection may not be empty.']), + ) + outputs = [ + ] + field = dt.MultipleChoiceField( + choices=[ + ('consistency', 'Consistency'), + ('availability', 'Availability'), + ('partition', 'Partition tolerance'), + ], + allow_empty=False + ) + + + +class TestBooleanField(FieldValues): + """ + Valid and invalid values for `BooleanField`. + """ + valid_inputs = { + 'true': True, + 'false': False, + '1': True, + '0': False, + 1: True, + 0: False, + True: True, + False: False, + } + invalid_inputs = { + 'foo': ['Not a valid boolean.'], + None: ['This field may not be null.'] + } + outputs = { + 'true': True, + 'false': False, + '1': True, + '0': False, + 1: True, + 0: False, + True: True, + False: False, + 'other': True + } + field = dt.BoolField() + + def test_disallow_unhashable_collection_types(self): + inputs = ( + [], + {}, + ) + field = self.field + for input_value in inputs: + with pytest.raises(ValidationError) as exc_info: + field.run_validation(input_value,self.context) + expected = ['Not a valid boolean.'] + assert exc_info.value.detail == expected + + +class TestNullBooleanField(TestBooleanField): + """ + Valid and invalid values for `NullBooleanField`. 
+ """ + valid_inputs = { + 'true': True, + 'false': False, + 'null': None, + True: True, + False: False, + None: None + } + invalid_inputs = { + 'foo': ['Not a valid boolean.'], + } + outputs = { + 'true': True, + 'false': False, + 'null': None, + True: True, + False: False, + 'other': True + } + field = dt.BoolField(allow_null=True) + + +class TestListField(FieldValues): + """ + Values for `ListField` with IntegerField as child. + """ + valid_inputs = [ + ([1, 2, 3], [1, 2, 3]), + (['1', '2', '3'], [1, 2, 3]), + ([], []) + ] + invalid_inputs = [ + ('not a list', ['Expected a list of items but got type "str".']), + ([1, 2, 'error', 'error'], {2: ['A valid integer is required.'], 3: ['A valid integer is required.']}), + ({'one': 'two'}, ['Expected a list of items but got type "dict".']) + ] + outputs = [ + ([1, 2, 3], [1, 2, 3]), + (['1', '2', '3'], [1, 2, 3]) + ] + field = dt.ArrayField(child=dt.IntField()) + + def test_collection_types_are_invalid_input(self): + field = dt.ArrayField(child=dt.StrField()) + input_value = ({'one': 'two'}) + + with pytest.raises(ValidationError) as exc_info: + field.deserialize(input_value,self.context) + assert exc_info.value.detail == ['Expected a list of items but got type "dict".'] + +class TestNestedArrayField(FieldValues): + """ + Values for nested `ArrayField` with IntegerField as child. + """ + valid_inputs = [ + ([[1, 2], [3]], [[1, 2], [3]]), + ([[]], [[]]) + ] + invalid_inputs = [ + (['not a list'], {0: ['Expected a list of items but got type "str".']}), + ([[1, 2, 'error'], ['error']], {0: {2: ['A valid integer is required.']}, 1: {0: ['A valid integer is required.']}}), + ([{'one': 'two'}], {0: ['Expected a list of items but got type "dict".']}) + ] + outputs = [ + ([[1, 2], [3]], [[1, 2], [3]]), + ] + field = dt.ArrayField(child=dt.ArrayField(child=dt.IntField())) + + +class TestEmptyArrayField(FieldValues): + """ + Values for `ArrayField` with allow_empty=False flag. 
+ """ + valid_inputs = {} + invalid_inputs = [ + ([], ['This list may not be empty.']) + ] + outputs = {} + field = dt.ArrayField(child=dt.IntField(), allow_empty=False) + + +class TestArrayFieldLengthLimit(FieldValues): + valid_inputs = () + invalid_inputs = [ + ((0, 1), ['Must have at least 3 items.']), + ((0, 1, 2, 3, 4, 5), ['Must have no more than 4 items.']), + ] + outputs = () + field = dt.ArrayField(child=dt.IntField(), min_items=3, max_items=4) + + +class TestArrayFieldExactLength(FieldValues): + valid_inputs = () + invalid_inputs = [ + ((0, 1), ['Must have 3 items.']), + ] + outputs = () + field = dt.ArrayField(child=dt.IntField(),exact_items=3) + + +class TestDictField(FieldValues): + """ + Values for `DictField` with CharField as child. + """ + valid_inputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ({}, {}), + ] + invalid_inputs = [ + ({'a': 1, 'b': None, 'c': None}, {'b': ['This field may not be null.'], 'c': ['This field may not be null.']}), + ('not a dict', ['Expected a dict of items but got type "str".']), + ] + outputs = [ + ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}), + ] + field = dt.DictField(child=dt.StrField()) + + def test_allow_null(self): + """ + If `allow_null=True` then `None` is a valid input. + """ + field = dt.DictField(allow_null=True) + output = field.run_validation(None,self.context) + assert output is None + + def test_allow_empty_disallowed(self): + """ + If allow_empty is False then an empty dict is not a valid input. + """ + field = dt.DictField(allow_empty=False) + with pytest.raises(ValidationError) as exc_info: + field.run_validation({},self.context) + + assert exc_info.value.detail == ['This dict may not be empty.'] + +class TestJSONField(FieldValues): + """ + Values for `JSONField`. 
+ """ + valid_inputs = [ + ({ + 'a': 1, + 'b': ['some', 'list', True, 1.23], + '3': None + }, { + 'a': 1, + 'b': ['some', 'list', True, 1.23], + '3': None + }), + ] + invalid_inputs = [ + ({'a': set()}, ['Not a valid JSON.']), + ] + outputs = [ + ({ + 'a': 1, + 'b': ['some', 'list', True, 1.23], + '3': 3 + }, { + 'a': 1, + 'b': ['some', 'list', True, 1.23], + '3': 3 + }), + ] + field = dt.JSONField() + +class TestConstantField(FieldValues): + """ + Values for `ConstantField`. + """ + field = dt.ConstantField(constant="abc") + + valid_inputs = { + 'abc' : 'abc', + } + invalid_inputs = { + 'abcd' : ['Must be "abc".'] + } + outputs = { + 'abc' : 'abc' + } + + def test_disallow_null_constant(self): + field = dt.ConstantField(constant=None) + with pytest.raises(ValidationError) as exc_info: + field.run_validation({},self.context) + + assert exc_info.value.detail == ['Must be None.'] + + + +class MockFile: + def __init__(self, name='', size=0, url=''): + self.name = name + self.size = size + self.url = url + + def __eq__(self, other): + return ( + isinstance(other, MockFile) and + self.name == other.name and + self.size == other.size and + self.url == other.url + ) + + +class TestFileField(FieldValues): + """ + Values for `FileField`. + """ + valid_inputs = [ + (MockFile(name='example', size=10), MockFile(name='example', size=10)) + ] + invalid_inputs = [ + ('invalid', ['The submitted data was not a file. 
Check the encoding type on the form.']), + (MockFile(name='example.txt', size=0), ['The submitted file is empty.']), + (MockFile(name='', size=10), ['No filename could be determined.']), + (MockFile(name='x' * 100, size=10), ['Ensure this filename has at most 10 characters (it has 100).']) + ] + outputs = [ + (MockFile(name='example.txt', url='/example.txt'), '/example.txt'), + ('', None) + ] + field = dt.FileField(max_length=10) + + + + +class TestMethodField: + def test_method_field(self): + class ExampleSerializer(dt.Serializer): + example_field = dt.MethodField() + + def get_example_field(self, obj): + return 'ran get_example_field(%d)' % obj['example_field'] + + serializer = ExampleSerializer({'example_field': 123}) + assert serializer.data == { + 'example_field': 'ran get_example_field(123)' + } + + def test_redundant_method_name(self): + class ExampleSerializer(dt.Serializer): + example_field = dt.MethodField('get_example_field') + + field = ExampleSerializer().fields['example_field'] + assert field.method_name == 'get_example_field' + + +class MockObject: + def __init__(self, **kwargs): + self._kwargs = kwargs + for key, val in kwargs.items(): + setattr(self, key, val) + + def __str__(self): + kwargs_str = ', '.join([ + '%s=%s' % (key, value) + for key, value in sorted(self._kwargs.items()) + ]) + return '' % kwargs_str + + +class MockQueryset: + def __init__(self, iterable): + self.items = iterable + + def __getitem__(self, val): + return self.items[val] + + def get(self, **lookup): + for item in self.items: + if all([ + getattr(item, key, None) == value + for key, value in lookup.items() + ]): + return item + raise ObjectDoesNotExist() + + +class BadType: + """ + When used as a lookup with a `MockQueryset`, these objects + will raise a `TypeError`, as occurs in Django when making + queryset lookups with an incorrect type for the lookup value. 
+ """ + def __eq__(self): + raise TypeError() + + +class TestRelatedField(APISimpleTestCase): + context = {} + def setUp(self): + self.queryset = MockQueryset([ + MockObject(pk=1, name='foo'), + MockObject(pk=2, name='bar'), + MockObject(pk=3, name='baz') + ]) + self.instance = self.queryset.items[2] + self.field = dt.RelatedField(queryset=self.queryset) + + def test_pk_related_lookup_exists(self): + instance = self.field.deserialize(self.instance.pk,self.context) + assert instance is self.instance + + def test_pk_related_lookup_does_not_exist(self): + with pytest.raises(ValidationError) as excinfo: + self.field.deserialize(4,self.context) + msg = excinfo.value.detail[0] + assert msg == 'Invalid pk "4" - object does not exist.' + + def test_pk_related_lookup_invalid_type(self): + with pytest.raises(ValidationError) as excinfo: + self.field.deserialize(BadType(),self.context) + msg = excinfo.value.detail[0] + assert msg == 'Incorrect type. Expected pk value, received BadType.' + + def test_pk_related_lookup_bool(self): + with pytest.raises(ValidationError) as excinfo: + self.field.deserialize(True,self.context) + msg = excinfo.value.detail[0] + assert msg == 'Incorrect type. Expected pk value, received bool.' 
class TestManyRelatedField(APISimpleTestCase):
    """Round-trip and validation behaviour of `ManyRelatedField`."""
    context = {}

    def setUp(self):
        self.queryset = MockQueryset([
            MockObject(pk=1, name='foo'),
            MockObject(pk=2, name='bar'),
            MockObject(pk=3, name='baz'),
        ])
        self.child_relation = dt.RelatedField(queryset=self.queryset)
        self.instance = self.queryset.items[2]
        self.field = dt.ManyRelatedField(child_relation=self.child_relation, allow_empty=False)

    def test_serialize(self):
        data = self.field.serialize(self.queryset, self.context)
        assert data == [1, 2, 3]

    def test_deserialize(self):
        data = self.field.deserialize([1, 2, 3], self.context)
        assert data == self.queryset.items

    def test_child_relation_is_list(self):
        # A bare object (not a list) must be rejected.
        with pytest.raises(ValidationError) as excinfo:
            self.field.deserialize(self.instance, self.context)
        msg = excinfo.value.detail[0]
        assert msg == 'Expected a list of items but got type "MockObject".'

    def test_child_relation_is_not_empty(self):
        # allow_empty=False forbids the empty list.
        with pytest.raises(ValidationError) as excinfo:
            self.field.deserialize([], self.context)
        msg = excinfo.value.detail[0]
        assert msg == 'This list may not be empty.'
class TestJSONParser(TestCase):
    """`JSONParser` on valid, empty and malformed request bodies."""

    def setUp(self):
        self.factory = APIRequestFactory()

    def test_json_parser(self):
        request = self.factory.post('/', '{"a": "b"}', content_type='application/json')
        assert dt.JSONParser().parse(request) == {'a': 'b'}

    def test_json_parser_with_empty_request(self):
        request = self.factory.post('/', '', content_type='application/json')
        with pytest.raises(ParseError):
            dt.JSONParser().parse(request)

    def test_json_parser_with_invalid_json(self):
        request = self.factory.post('/', '{"a": "b"', content_type='application/json')
        with pytest.raises(ParseError):
            dt.JSONParser().parse(request)


class TestORJSONParser(TestCase):
    """`ORJSONParser` on valid, empty and malformed request bodies."""

    def setUp(self):
        self.factory = APIRequestFactory()

    def test_orjson_parser(self):
        request = self.factory.post('/', '{"a": "b"}', content_type='application/json')
        assert dt.ORJSONParser().parse(request) == {'a': 'b'}

    def test_orjson_parser_with_empty_request(self):
        request = self.factory.post('/', '', content_type='application/json')
        with pytest.raises(ParseError):
            dt.ORJSONParser().parse(request)

    def test_orjson_parser_with_invalid_json(self):
        request = self.factory.post('/', '{"a": "b"', content_type='application/json')
        with pytest.raises(ParseError):
            dt.ORJSONParser().parse(request)
class TestUJSONParser(TestCase):
    """`UJSONParser` on valid, empty and malformed request bodies."""

    def setUp(self):
        self.factory = APIRequestFactory()

    def test_ujson_parser(self):
        request = self.factory.post('/', '{"a": "b"}', content_type='application/json')
        assert dt.UJSONParser().parse(request) == {'a': 'b'}

    def test_ujson_parser_with_empty_request(self):
        request = self.factory.post('/', '', content_type='application/json')
        with pytest.raises(ParseError):
            dt.UJSONParser().parse(request)

    def test_ujson_parser_with_invalid_json(self):
        request = self.factory.post('/', '{"a": "b"', content_type='application/json')
        with pytest.raises(ParseError):
            dt.UJSONParser().parse(request)


class JSONRendererTests(TestCase):
    """Byte-level output checks for `JSONRenderer`."""

    def test_render_lazy_strings(self):
        """Lazy translated strings render as plain JSON strings."""
        ret = dt.JSONRenderer().render('test')
        self.assertEqual(ret, b'"test"')

    def test_render_none(self):
        """None renders as an empty byte string."""
        ret = dt.JSONRenderer().render(None)
        self.assertEqual(ret, b'')

    def test_render_dict(self):
        """Dicts render compactly (no spaces)."""
        ret = dt.JSONRenderer().render({'a': 'b'})
        self.assertEqual(ret, b'{"a":"b"}')

    def test_render_list(self):
        """Lists render compactly (no spaces)."""
        ret = dt.JSONRenderer().render(['a', 'b'])
        self.assertEqual(ret, b'["a","b"]')

    def test_render_int(self):
        """Integers render as bare numbers."""
        ret = dt.JSONRenderer().render(1)
        self.assertEqual(ret, b'1')

    def test_render_float(self):
        """Floats render as bare numbers."""
        ret = dt.JSONRenderer().render(1.1)
        self.assertEqual(ret, b'1.1')

    def test_render_bool(self):
        """Booleans render as JSON literals."""
        ret = dt.JSONRenderer().render(True)
        self.assertEqual(ret, b'true')


class ORJSONRendererTests(TestCase):
    """Byte-level output checks for `ORJSONRenderer`."""

    def test_render_lazy_strings(self):
        """Lazy translated strings render as plain JSON strings."""
        ret = dt.ORJSONRenderer().render('test')
        self.assertEqual(ret, b'"test"')

    def test_render_none(self):
        """None renders as an empty byte string."""
        ret = dt.ORJSONRenderer().render(None)
        self.assertEqual(ret, b'')

    def test_render_dict(self):
        """Dicts render compactly (no spaces)."""
        ret = dt.ORJSONRenderer().render({'a': 'b'})
        self.assertEqual(ret, b'{"a":"b"}')

    def test_render_list(self):
        """Lists render compactly (no spaces)."""
        ret = dt.ORJSONRenderer().render(['a', 'b'])
        self.assertEqual(ret, b'["a","b"]')

    def test_render_int(self):
        """Integers render as bare numbers."""
        ret = dt.ORJSONRenderer().render(1)
        self.assertEqual(ret, b'1')

    def test_render_float(self):
        """Floats render as bare numbers."""
        ret = dt.ORJSONRenderer().render(1.1)
        self.assertEqual(ret, b'1.1')

    def test_render_bool(self):
        """Booleans render as JSON literals."""
        ret = dt.ORJSONRenderer().render(True)
        self.assertEqual(ret, b'true')
+ """ + ret = dt.UJSONRenderer().render(None) + self.assertEqual(ret, b'') + + def test_render_dict(self): + """ + Test render dict + """ + ret = dt.UJSONRenderer().render({'a': 'b'}) + self.assertEqual(ret, b'{"a":"b"}') + + + def test_render_list(self): + """ + Test render list + """ + ret = dt.UJSONRenderer().render(['a', 'b']) + self.assertEqual(ret, b'["a","b"]') + + + def test_render_int(self): + """ + Test render int + """ + ret = dt.UJSONRenderer().render(1) + self.assertEqual(ret, b'1') + + + def test_render_float(self): + """ + Test render float + """ + ret = dt.UJSONRenderer().render(1.1) + self.assertEqual(ret, b'1.1') + + def test_render_bool(self): + """ + Test render bool + """ + ret = dt.UJSONRenderer().render(True) + self.assertEqual(ret, b'true') + + + + + \ No newline at end of file diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..929aecb --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,251 @@ +from django.test import TestCase +import drf_turbo as dt + + +class TestJsonResponse(TestCase): + def test_json_response(self): + """ + Test json response + """ + resp = dt.JSONResponse({'a': 'b'}) + self.assertEqual(resp.content, b'{"a":"b"}') + + def test_json_response_with_status(self): + """ + Test json response with status + """ + resp = dt.JSONResponse({'a': 'b'}, status=400) + self.assertEqual(resp.status_code, 400) + + + def test_json_response_with_content_type(self): + """ + Test json response with content type + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json') + self.assertEqual(resp['Content-Type'], 'application/json') + + + def test_json_response_with_content_type_with_charset(self): + """ + Test json response with content type with charset + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8') + + + def 
test_json_response_with_content_type_with_charset_and_encoding(self): + """ + Test json response with content type with charset and encoding + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8') + + + def test_json_response_with_content_type_with_charset_and_encoding_and_encoding_errors(self): + """ + Test json response with content type with charset and encoding and encoding errors + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + + + def test_json_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors(self): + """ + Test json response with content type with charset and encoding and encoding errors and charset errors + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + + + def test_json_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent(self): + """ + Test json response with content type with charset and encoding and encoding errors and charset errors and indent + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4') + + + def 
test_json_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators(self): + """ + Test json response with content type with charset and encoding and encoding errors and charset errors and indent and separators + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + + + def test_json_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators_and_sort_keys(self): + """ + Test json response with content type with charset and encoding and encoding errors and charset errors and indent and separators and sort keys + """ + resp = dt.JSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + + + +class TestORJSONResponse(TestCase): + def test_orjson_response(self): + """ + Test orjson response + """ + resp = dt.ORJSONResponse({'a': 'b'}) + self.assertEqual(resp.content, b'{"a":"b"}') + + def test_orjson_response_with_status(self): + """ + Test orjson response with status + """ + resp = dt.ORJSONResponse({'a': 'b'}, status=400) + self.assertEqual(resp.status_code, 400) + + + def test_orjson_response_with_content_type(self): + """ + Test orjson response with content type + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json') + self.assertEqual(resp['Content-Type'], 'application/json') + + + def 
test_orjson_response_with_content_type_with_charset(self): + """ + Test orjson response with content type with charset + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8') + + def test_orjson_response_with_content_type_with_charset_and_encoding(self): + """ + Test orjson response with content type with charset and encoding + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8') + + + def test_orjson_response_with_content_type_with_charset_and_encoding_and_encoding_errors(self): + """ + Test orjson response with content type with charset and encoding and encoding errors + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + + def test_orjson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors(self): + """ + Test orjson response with content type with charset and encoding and encoding errors and charset errors + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + + + def test_orjson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent(self): + """ + Test orjson response with content type with charset and encoding and encoding errors and charset errors and indent + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; 
charset-errors=ignore; indent=4') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4') + + + def test_orjson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators(self): + """ + Test orjson response with content type with charset and encoding and encoding errors and charset errors and indent and separators + """ + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + + + def test_orjson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators_and_sort_keys(self): + resp = dt.ORJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + + +class TestUJSONResponse(TestCase): + def test_ujson_response(self): + """ + Test ujson response + """ + resp = dt.UJSONResponse({'a': 'b'}) + self.assertEqual(resp.content, b'{"a":"b"}') + + def test_ujson_response_with_status(self): + """ + Test ujson response with status + """ + resp = dt.UJSONResponse({'a': 'b'}, status=400) + self.assertEqual(resp.status_code, 400) + + + def test_ujson_response_with_content_type(self): + """ + Test ujson response with content type + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json') + self.assertEqual(resp['Content-Type'], 'application/json') + + + def 
test_ujson_response_with_content_type_with_charset(self): + """ + Test ujson response with content type with charset + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8') + + def test_ujson_response_with_content_type_with_charset_and_encoding(self): + """ + Test ujson response with content type with charset and encoding + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8') + + + def test_ujson_response_with_content_type_with_charset_and_encoding_and_encoding_errors(self): + """ + Test ujson response with content type with charset and encoding and encoding errors + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore') + + def test_ujson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors(self): + """ + Test ujson response with content type with charset and encoding and encoding errors and charset errors + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore') + + def test_ujson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent(self): + """ + Test ujson response with content type with charset and encoding and encoding errors and charset errors and indent + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; 
charset-errors=ignore; indent=4') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4') + + + def test_ujson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators(self): + """ + Test ujson response with content type with charset and encoding and encoding errors and charset errors and indent and separators + """ + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :)') + + def test_ujson_response_with_content_type_with_charset_and_encoding_and_encoding_errors_and_charset_errors_and_indent_and_separators_and_sort_keys(self): + resp = dt.UJSONResponse({'a': 'b'}, content_type='application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + self.assertEqual(resp['Content-Type'], 'application/json; charset=UTF-8; encoding=UTF-8; encoding-errors=ignore; charset-errors=ignore; indent=4; separators=(, :); sort-keys=true') + + +class TestSuccessResponse(TestCase): + def test_success_response(self): + """ + Test success response + """ + resp = dt.SuccessResponse({'a': 'b'}) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.content, b'{"message":"Success","data":{"a":"b"},"error":false}') + + +class TestErrorResponse(TestCase): + def test_error_response(self): + """ + Test error response + """ + resp = dt.ErrorResponse({'a': 'b'}) + self.assertEqual(resp.status_code, 400) + self.assertEqual(resp.content, b'{"message":"Bad request","data":{"a":"b"},"error":true}') diff --git a/tests/test_serializer.py b/tests/test_serializer.py new file mode 
# NOTE(review): this chunk of the patch was collapsed onto a few very long
# lines; reconstructed here as the plain contents of tests/test_serializer.py
# (plus the adjacent tox.ini, carried as a comment block at the end).

# --- tests/test_serializer.py ---
"""Tests for drf_turbo.Serializer: validation, only/exclude, defaults, nesting."""
import pickle
import re
from collections import ChainMap
from collections.abc import Mapping

import pytest

import drf_turbo as dt
from drf_turbo.exceptions import OnlyAndExcludeError


class TestSerializer:
    # pytest nose-style `setup` hook; deprecated in pytest >= 7.2 in favour
    # of `setup_method`, kept here to match the rest of the suite.
    def setup(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()
        self.Serializer = ExampleSerializer
        self.context = {}

    def test_valid_serializer(self):
        serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc', 'integer': 123}
        assert serializer.data == {'char': 'abc', 'integer': 123}
        assert serializer.errors == {}

    def test_invalid_serializer(self):
        serializer = self.Serializer(data={'char': 'abc'})
        assert not serializer.is_valid()
        assert serializer.validated_data == {}
        assert serializer.data == {'char': 'abc'}
        assert serializer.errors == {'integer': ['This field is required.']}

    def test_invalid_datatype(self):
        serializer = self.Serializer(data=[{'char': 'abc'}])
        assert not serializer.is_valid()
        assert serializer.validated_data == {}
        # NOTE(review): dead assertion from the original file; `.data` for an
        # invalid list payload is apparently not {} — confirm and either
        # re-enable or delete.
        # assert serializer.data == {}
        assert serializer.errors == ['Invalid data type: list']

    def test_partial_validation(self):
        serializer = self.Serializer(data={'char': 'abc'}, partial=True)
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc'}
        assert serializer.errors == {}

    def test_empty_serializer(self):
        serializer = self.Serializer()
        assert serializer.data == {'char': '', 'integer': None}

    def test_missing_attribute_during_serialization(self):
        class MissingAttributes:
            pass
        instance = MissingAttributes()
        serializer = self.Serializer(instance)
        with pytest.raises(AttributeError):
            serializer.data

    def test_data_access_before_save_raises_error(self):
        def create(validated_data):
            return validated_data
        serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
        serializer.create = create
        assert serializer.is_valid()
        assert serializer.data == {'char': 'abc', 'integer': 123}
        # .data was accessed above, so .save() must refuse to run.
        with pytest.raises(AssertionError):
            serializer.save()

    def test_validate_none_data(self):
        data = None
        serializer = self.Serializer(data=data)
        assert not serializer.is_valid()
        assert serializer.errors == ['Invalid data type: NoneType']

    def test_serialize_chainmap(self):
        data = ChainMap({'char': 'abc'}, {'integer': 123})
        serializer = self.Serializer(data=data)
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc', 'integer': 123}
        assert serializer.errors == {}

    def test_serialize_custom_mapping(self):
        class SinglePurposeMapping(Mapping):
            def __getitem__(self, key):
                return 'abc' if key == 'char' else 123

            def __iter__(self):
                yield 'char'
                yield 'integer'

            def __len__(self):
                return 2

        serializer = self.Serializer(data=SinglePurposeMapping())
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc', 'integer': 123}
        assert serializer.errors == {}

    def test_custom_deserialize(self):
        """
        deserialize() is expected to return a dict, but subclasses may
        return an application-specific type.
        """
        class Point:
            def __init__(self, srid, x, y):
                self.srid = srid
                self.coords = (x, y)

        # Declares a serializer that converts data into an object.
        class NestedPointSerializer(dt.Serializer):
            longitude = dt.FloatField(attr='x')
            latitude = dt.FloatField(attr='y')

            def deserialize(self, data, context):
                kwargs = super().deserialize(data, context)
                return Point(srid=4326, **kwargs)

        serializer = NestedPointSerializer(data={'longitude': 6.958307, 'latitude': 50.941357})
        assert serializer.is_valid()
        assert isinstance(serializer.validated_data, Point)
        assert serializer.validated_data.srid == 4326
        assert serializer.validated_data.coords[0] == 6.958307
        assert serializer.validated_data.coords[1] == 50.941357
        assert serializer.errors == {}

    def test_iterable_validators(self):
        """
        Ensure the `validators` parameter is compatible with reasonable iterables.
        """
        data = {'char': 'abc', 'integer': 123}

        for validators in ([], (), set()):
            class ExampleSerializer(dt.Serializer):
                char = dt.StrField(validators=validators)
                integer = dt.IntField()

            serializer = ExampleSerializer(data=data)
            assert serializer.is_valid()
            assert serializer.validated_data == data
            assert serializer.errors == {}

        def raise_exception(value):
            raise dt.ValidationError('Raised error')

        for validators in ([raise_exception], (raise_exception,), {raise_exception}):
            class ExampleSerializer(dt.Serializer):
                char = dt.StrField(validators=validators)
                integer = dt.IntField()

            serializer = ExampleSerializer(data=data)
            assert not serializer.is_valid()
            assert serializer.data == data
            assert serializer.validated_data == {}

    def test_only_fields(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

        serializer = ExampleSerializer({'char': 'abc', 'integer': 123}, only=('char',))
        assert serializer.data == {'char': 'abc'}

    def test_exclude_fields(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

        serializer = ExampleSerializer({'char': 'abc', 'integer': 123}, exclude=('char',))
        assert serializer.data == {'integer': 123}

    def test_not_allowed_only_and_exclude_fields(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

        # only= and exclude= are mutually exclusive.
        with pytest.raises(OnlyAndExcludeError):
            ExampleSerializer({'char': 'abc', 'integer': 123}, only=('char',), exclude=('integer',))

    def test_nested_only_fields(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

        class NestedSerializer(dt.Serializer):
            example = ExampleSerializer()

        serializer = NestedSerializer({'example': {'char': 'abc', 'integer': 123}}, only=('example__char',))
        assert serializer.data == {'example': {'char': 'abc'}}

    def test_nested_exclude_fields(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

        class NestedSerializer(dt.Serializer):
            example = ExampleSerializer()

        serializer = NestedSerializer({'example': {'char': 'abc', 'integer': 123}}, exclude=('example__char',))
        assert serializer.data == {'example': {'integer': 123}}

    def test_nested_only_inheritance(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()
            other = dt.StrField()

        class NestedSerializer(dt.Serializer):
            foo = dt.StrField()
            bar = dt.IntField()
            example = ExampleSerializer(only=('char', 'other'))

        serializer = NestedSerializer(
            {'foo': 'foo', 'bar': 123, 'example': {'char': 'abc', 'integer': 123, 'other': 'other'}},
            only=('foo', 'example__other'),
        )
        assert serializer.data == {'example': {'other': 'other'}, 'foo': 'foo'}

    def test_nested_exclude_inheritance(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()
            other = dt.StrField()

        class NestedSerializer(dt.Serializer):
            foo = dt.StrField()
            bar = dt.IntField()
            example = ExampleSerializer(exclude=('char',))

        serializer = NestedSerializer(
            {'foo': 'foo', 'bar': 123, 'example': {'char': 'abc', 'integer': 123, 'other': 'other'}},
            exclude=('foo', 'example__other'),
        )
        assert serializer.data == {'example': {'integer': 123}, 'bar': 123}


class TestValidateMethod:
    def test_non_field_error_validate_method(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

            def validate(self, attrs):
                raise dt.ValidationError('Non field error')

        serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
        assert not serializer.is_valid()
        assert serializer.errors == ['Non field error']

    def test_field_error_validate_method(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField()
            integer = dt.IntField()

            def validate(self, attrs):
                raise dt.ValidationError({'char': 'Field error'})

        serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
        assert not serializer.is_valid()
        assert serializer.errors == {'char': 'Field error'}


class MockObject:
    """Attribute bag used as a stand-in model instance in the tests below."""

    def __init__(self, **kwargs):
        self._kwargs = kwargs
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __str__(self):
        kwargs_str = ', '.join(
            '%s=%s' % (key, value)
            for key, value in sorted(self._kwargs.items())
        )
        # BUG FIX: the original returned `'' % kwargs_str`, which raises
        # TypeError ("not all arguments converted during string formatting")
        # for any non-empty kwargs — the literal's text was evidently lost
        # to markup mangling.
        return '<MockObject %s>' % kwargs_str


class TestNotRequiredOutput:
    def test_not_required_output_for_dict(self):
        """
        'required=False' should allow a dictionary key to be missing in output.
        """
        class ExampleSerializer(dt.Serializer):
            omitted = dt.StrField(required=False)
            included = dt.StrField()

        serializer = ExampleSerializer(data={'included': 'abc'})
        serializer.is_valid()
        assert serializer.data == {'included': 'abc'}

    def test_not_required_output_for_object(self):
        """
        'required=False' should allow an object attribute to be missing in output.
        """
        class ExampleSerializer(dt.Serializer):
            omitted = dt.StrField(required=False)
            included = dt.StrField()

            def create(self, validated_data):
                return MockObject(**validated_data)

        serializer = ExampleSerializer(data={'included': 'abc'})
        serializer.is_valid()
        serializer.save()
        assert serializer.data == {'included': 'abc'}


class TestDefaultOutput:
    def setup(self):
        class ExampleSerializer(dt.Serializer):
            has_default = dt.StrField(default_value='x', required=False)
            has_default_callable = dt.StrField(default_value=lambda: 'y', required=False)
            no_default = dt.StrField()
        self.Serializer = ExampleSerializer

    def test_default_used_for_dict(self):
        """
        'default="something"' should be used if the dictionary key is missing from input.
        """
        serializer = self.Serializer({'no_default': 'abc'})
        assert serializer.data == {'has_default': 'x', 'has_default_callable': 'y', 'no_default': 'abc'}

    def test_default_used_for_object(self):
        """
        'default="something"' should be used if the object attribute is missing from input.
        """
        instance = MockObject(no_default='abc')
        serializer = self.Serializer(instance)
        assert serializer.data == {'has_default': 'x', 'has_default_callable': 'y', 'no_default': 'abc'}

    def test_default_not_used_when_in_dict(self):
        """
        'default="something"' should not be used if the dictionary key is present in input.
        """
        serializer = self.Serializer({'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'})
        assert serializer.data == {'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'}

    def test_default_not_used_when_in_object(self):
        """
        'default="something"' should not be used if the object attribute is present in input.
        """
        instance = MockObject(has_default='def', has_default_callable='ghi', no_default='abc')
        serializer = self.Serializer(instance)
        assert serializer.data == {'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'}

    def test_default_for_dotted_source(self):
        """
        'default="something"' should be used when a traversed attribute is missing from input.
        """
        class Serializer(dt.Serializer):
            traversed = dt.StrField(default_value='x', attr='traversed.attr', required=False)

        assert Serializer({}).data == {'traversed': 'x'}
        assert Serializer({'traversed': {}}).data == {'traversed': 'x'}
        assert Serializer({'traversed': None}).data == {'traversed': 'x'}

        assert Serializer({'traversed': {'attr': 'abc'}}).data == {'traversed': 'abc'}

    def test_default_for_nested_serializer(self):
        class NestedSerializer(dt.Serializer):
            a = dt.StrField(default_value='1', required=False)
            c = dt.StrField(default_value='2', attr='b.c', required=False)

        class Serializer(dt.Serializer):
            nested = NestedSerializer()

        assert Serializer({'nested': None}).data == {'nested': None}
        assert Serializer({'nested': {}}).data == {'nested': {'a': '1', 'c': '2'}}
        assert Serializer({'nested': {'a': '3', 'b': {}}}).data == {'nested': {'a': '3', 'c': '2'}}
        assert Serializer({'nested': {'a': '3', 'b': {'c': '4'}}}).data == {'nested': {'a': '3', 'c': '4'}}

    def test_default_for_allow_null(self):
        """
        Without an explicit default, allow_null implies default=None when
        serializing. #5518 #5708
        """
        class Serializer(dt.Serializer):
            foo = dt.StrField()
            bar = dt.StrField(attr='foo.bar', allow_null=True)
            optional = dt.StrField(required=False, allow_null=True)

        # allow_null=True should imply default=None when serializing:
        assert Serializer({'foo': None}).data == {'foo': None, 'bar': None, 'optional': None, }


class TestCacheSerializerData:
    def test_cache_serializer_data(self):
        """
        Caching serializer data with pickle will drop the serializer info,
        but does preserve the data itself.
        """
        class ExampleSerializer(dt.Serializer):
            field1 = dt.StrField()
            field2 = dt.StrField()

        serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
        pickled = pickle.dumps(serializer.data)
        data = pickle.loads(pickled)
        assert data == {'field1': 'a', 'field2': 'b'}


class TestDefaultInclusions:
    def setup(self):
        class ExampleSerializer(dt.Serializer):
            char = dt.StrField(default_value='abc', required=False)
            integer = dt.IntField()
        self.Serializer = ExampleSerializer

    def test_default_should_included_on_create(self):
        serializer = self.Serializer(data={'integer': 456})
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc', 'integer': 456}
        assert serializer.errors == {}

    def test_default_should_be_included_on_update(self):
        instance = MockObject(char='def', integer=123)
        serializer = self.Serializer(instance, data={'integer': 456})
        assert serializer.is_valid()
        assert serializer.validated_data == {'char': 'abc', 'integer': 456}
        assert serializer.errors == {}

    def test_default_should_not_be_included_on_partial_update(self):
        instance = MockObject(char='def', integer=123)
        serializer = self.Serializer(instance, data={'integer': 456}, partial=True)
        assert serializer.is_valid()
        assert serializer.validated_data == {'integer': 456}
        assert serializer.errors == {}


class TestSerializerValidationWithCompiledRegexField:
    def setup(self):
        class ExampleSerializer(dt.Serializer):
            name = dt.RegexField(re.compile(r'\d'), required=True)
        self.Serializer = ExampleSerializer

    def test_validation_success(self):
        serializer = self.Serializer(data={'name': '2'})
        assert serializer.is_valid()
        assert serializer.validated_data == {'name': '2'}
        assert serializer.errors == {}


# --- tox.ini (non-Python file from the same patch, reproduced for completeness) ---
# [tox]
# envlist = py36, py37, py38, flake8
#
# [travis]
# python =
#     3.8: py38
#     3.7: py37
#     3.6: py36
#
# [testenv:flake8]
# basepython = python
# deps = flake8
# commands = flake8 drf_turbo tests
#
# [testenv]
# setenv =
#     PYTHONPATH = {toxinidir}
#
# commands = python setup.py test