Skip to content

Commit

Permalink
refactor: Update docs and add tests
Browse files Browse the repository at this point in the history
This renames the tests to be picked up by test.sh. It also adds a test for bulk_insert_models
  • Loading branch information
taylor-cedar committed Jul 13, 2021
1 parent 5575da9 commit 3ef7cff
Show file tree
Hide file tree
Showing 15 changed files with 396 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__
91 changes: 89 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,93 @@
# Django Bulk Load
Load large sets of Django models into the DB using Postgres COPY command. This is a more performant alternative
to bulk_insert and bulk_upsert in Django.
to [bulk_create](https://docs.djangoproject.com/en/3.2/ref/models/querysets/#bulk-create) and
[bulk_update](https://docs.djangoproject.com/en/3.2/ref/models/querysets/#bulk-update) in Django.

Note: Currently, this library only supports Postgres. Other databases may be added in the future

## Install
```shell
pip install django-bulk-load
```

## API
See `bulk_load.py` for full API
Just import and use the functions below. No need to change settings.py

### bulk_insert_models()
INSERT a batch of models. It makes use of the Postgres COPY command to improve speed. If a row already exists, the entire
insert will fail unless `ignore_conflicts=True` is passed.

```python
from django_bulk_load import bulk_insert_models

bulk_insert_models(
models: Sequence[Model],
ignore_conflicts: bool = False,
return_models: bool = False,
)
```

### bulk_upsert_models()
UPSERT a batch of models. Replicates [UPSERTing](https://wiki.postgresql.org/wiki/UPSERT) for a large set of models.
By default, it matches existing models using the model `pk`, but you can specify matching on other fields with
`pk_field_names`. See bulk_load.py for descriptions of other parameters.

```python
from django_bulk_load import bulk_upsert_models

bulk_upsert_models(
models: Sequence[Model],
pk_field_names: Sequence[str] = None,
insert_only_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
return_models: bool = False,
)
```

### bulk_update_models()
UPDATE a batch of models. If the model is not found in the database, it is ignored. See bulk_load.py for descriptions of other parameters.

```python
from django_bulk_load import bulk_update_models

bulk_update_models(
models: Sequence[Model],
update_field_names: Sequence[str] = None,
pk_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
return_models: bool = False,
)
```

### bulk_insert_changed_models()
INSERTs a new record in the database when a model field has changed in any of `compare_field_names`,
with respect to its latest state, where "latest" is defined by ordering the records
for a given primary key by sorting in descending order on the column passed in
`order_field_name`. Does not INSERT a new record if the latest record has not changed.

```python
from django_bulk_load import bulk_insert_changed_models
bulk_insert_changed_models(
models: Sequence[Model],
pk_field_names: Sequence[str],
compare_field_names: Sequence[str],
order_field_name=None,
return_models=None,
)
```

### bulk_select_model_dicts()
Select/Get model dictionaries by filter_field_names. It returns dictionaries, not Django
models for performance reasons. This is useful when querying a very large set of models or multiple field IN clauses.

```python
from django_bulk_load import bulk_select_model_dicts

bulk_select_model_dicts(
model_class: Type[Model],
filter_field_names: Iterable[str],
select_field_names: Iterable[str],
filter_data: Iterable[Sequence],
skip_filter_transform=False,
)
```
2 changes: 2 additions & 0 deletions django_bulk_load/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
bulk_insert_changed_models,
bulk_load_models_with_queries,
bulk_select_model_dicts,
bulk_insert_models,
bulk_update_models,
bulk_upsert_models,
)

__all__ = [
"bulk_select_model_dicts",
"bulk_insert_models",
"bulk_update_models",
"bulk_upsert_models",
"bulk_insert_changed_models",
Expand Down
98 changes: 81 additions & 17 deletions django_bulk_load/bulk_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.db.backends.utils import CursorWrapper
from django.db.models import AutoField, Model
from psycopg2.extras import execute_values
from psycopg2.sql import SQL, Composable, Identifier
from psycopg2.sql import Composable

from .django import (
django_field_to_query_value,
Expand All @@ -23,17 +23,19 @@
create_temp_table,
generate_insert_on_not_match_latest,
generate_insert_query,
generate_insert_for_update_query,
generate_select_latest,
generate_select_query,
generate_update_query,
generate_values_select_query,
copy_query
)
from .utils import generate_table_name

logger = logging.getLogger(__name__)


def create_table_and_load(
def create_temp_table_and_load(
models: Sequence[Model],
connection: BaseDatabaseWrapper,
cursor: CursorWrapper,
Expand All @@ -57,9 +59,7 @@ def create_table_and_load(
tsv_buffer = models_to_tsv_buffer(models, fields, connection=connection)
cursor.execute(temp_table_query)
cursor.copy_expert(
SQL("COPY {table_name} FROM STDIN NULL '\\N' DELIMITER '\t' CSV").format(
table_name=Identifier(table_name)
),
copy_query(table_name),
tsv_buffer,
)

Expand Down Expand Up @@ -110,7 +110,7 @@ def bulk_load_models_with_queries(
"Starting loading models",
extra=dict(model_count=len(models), table_name=table_name),
)
loading_table_name = create_table_and_load(
loading_table_name = create_temp_table_and_load(
models=models,
table_name=loading_table_name,
field_names=field_names,
Expand Down Expand Up @@ -141,6 +141,66 @@ def bulk_load_models_with_queries(
return results


def bulk_insert_models(
    models: Sequence[Model],
    ignore_conflicts: bool = False,
    return_models: bool = False,
):
    """
    INSERT a batch of models, staging them in a loading table first so the Postgres COPY
    command can be used for speed.

    Unless ignore_conflicts is set, a row that already exists makes the entire insert fail.

    :param models: Django model list/tuple. Either every model has its pk set or none of them do.
    :param ignore_conflicts: If there is an error on a unique constraint, ignore it instead of erroring
        (the generated INSERT gets an ON CONFLICT DO NOTHING suffix)
    :param return_models: Return the models from the DB. Defaults to False, since this can
        significantly degrade performance
    :return: None or List[Model] depending upon the return_models param. Returned models will not be
        in the same order they were passed in.
        NOTE(review): the returned rows come from the clause appended by add_returning; confirm
        whether rows skipped because of ignore_conflicts are included.
    :raises ValueError: If some models have a pk set and others do not
    """
    if not models:
        logger.warning("No models passed to bulk_insert_models")
        if return_models:
            return []
        return None

    # The INSERT needs one fixed column list for the whole batch. The PK column must be left out
    # when it is unset and included when it is set; including it for models without a PK would
    # produce NULL errors. A batch mixing both cases therefore cannot be loaded in one statement.
    pk_presence = {instance.pk is not None for instance in models}
    if len(pk_presence) > 1:
        raise ValueError(
            "Mix of models with PK and no PK specified. This can cause issues. Split into 2 groups instead"
        )
    # models is non-empty here, so exactly one value remains in the set.
    (has_pks,) = pk_presence

    meta = models[0]._meta
    target_table = meta.db_table
    staging_table = generate_table_name(target_table)

    query = generate_insert_query(
        table_name=target_table,
        loading_table_name=staging_table,
        ignore_conflicts=ignore_conflicts,
        insert_fields=get_model_fields(meta, include_auto_fields=has_pks),
    )

    if return_models:
        # Have the INSERT itself hand back rows so they can be turned into models.
        query = add_returning(query, table_name=target_table)

    return bulk_load_models_with_queries(
        models=models,
        loading_table_name=staging_table,
        load_queries=[query],
        return_models=return_models,
    )


def bulk_update_models(
models: Sequence[Model],
update_field_names: Sequence[str] = None,
Expand All @@ -150,8 +210,8 @@ def bulk_update_models(
return_models: bool = False,
):
"""
Update a batch of models. This is useful for updating a large set of models at once. It makes use
of Postgres COPY command to improve speed
UPDATE a batch of models. If the model is not found in the database, it is ignored.
:param models: Django model list/tuple
:param update_field_names: Field to update (defaults to all fields)
:param pk_field_names: Fields used to match existing models in the DB. By default uses model primary key.
Expand Down Expand Up @@ -255,8 +315,10 @@ def bulk_upsert_models(
return_models: bool = False,
):
"""
Upsert a batch of models. This is useful for updating a large set of models at once. It makes use
of Postgres COPY command to improve speed
UPSERT a batch of models. Replicates [UPSERTing](https://wiki.postgresql.org/wiki/UPSERT) for a large set of models.
By default, it matches existing models using the model `pk`, but you can specify matching on other fields with
`pk_field_names`.
:param models: Django model list/tuple
:param insert_only_field_names: Names of model fields to only insert, never update (i.e. created_on)
:param pk_field_names: Fields used to match existing models in the DB. By default uses model primary key.
Expand Down Expand Up @@ -313,7 +375,7 @@ def bulk_upsert_models(
)
)

insert_query = generate_insert_query(
insert_query = generate_insert_for_update_query(
table_name=table_name,
loading_table_name=loading_table_name,
insert_fields=insert_fields,
Expand Down Expand Up @@ -349,10 +411,11 @@ def bulk_insert_changed_models(
return_models=None,
):
"""
Inserts a new record when the model has changed in any of `compare_field_names`,
INSERTs a new record in the database when a model field has changed in any of `compare_field_names`,
with respect to its latest state, where "latest" is defined by ordering the records
for a given primary key by sorting in descending order on the column passed in
`order_field_name`. Otherwise, does nothing.
`order_field_name`. Does not INSERT a new record if the latest record has not changed.
:param models: Django model list/tuple
:param pk_field_names: Fields used to match existing models in the DB. By default uses model primary key.
:param order_field_name: Field to determine the latest record (normally an AutoField or last_modified datetime type field)
Expand Down Expand Up @@ -431,13 +494,14 @@ def bulk_select_model_dicts(
skip_filter_transform=False,
) -> List[Dict]:
"""
Select/Get model dictionaries by filter_field_names. This is useful when
querying a very large set of models or multiple value IN clauses
Select/Get model dictionaries by filter_field_names. It returns dictionaries, not Django
models for performance reasons. This is useful when querying a very large set of models
or multiple field IN clauses.
:param model_class: Model class to query. For instance django.contrib.auth.models.User
:param filter_field_names: Fields to use in the query.
:param select_field_names: The fields to return in the result dictionaries. The dictionaries will always contain
the model_filter keys in addition to any fields in select_field_names
the filter_field_names keys in addition to any fields in select_field_names
:param filter_data: Values (normally tuples) of the filter_field_names. For instance if filter_field_names=["field1", "field2"],
filter_data may be [(12, "hello"), (23, "world"), (35, "fun"), ...]
:param skip_filter_transform: Normally the function converts the filter_data into DB specific values. This is useful
Expand Down Expand Up @@ -500,7 +564,7 @@ def bulk_select_model_dicts(
{
select_field_map[column]
.attname: select_field_map[column]
.from_db_value(value, connection=connection)
.from_db_value(value, expression=None, connection=connection)
if hasattr(select_field_map[column], "from_db_value")
else value
for column, value in zip(columns, row)
Expand Down
10 changes: 4 additions & 6 deletions django_bulk_load/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from django.db.models.options import Options
from psycopg2.extras import Json

from django_json_text_field.fields import SerializedField

from .utils import NULL_CHARACTER


Expand All @@ -30,9 +28,9 @@ def records_to_models(
field.column: DjangoFieldInfo(
field_name=field.attname,
deserializer=(
field.deserializer
if isinstance(field, SerializedField)
else lambda x: x
field.from_db_value
if hasattr(field, "from_db_value")
else lambda value, expression, connection: value
),
)
for field in get_model_fields(model_class._meta, include_auto_fields=True)
Expand All @@ -43,7 +41,7 @@ def records_to_models(
zipped_results = dict(zip(columns, row))
attrs = {
django_field_info.field_name: (
django_field_info.deserializer(zipped_results[column])
django_field_info.deserializer(zipped_results[column], None, None)
if column in zipped_results
else models.DEFERRED
)
Expand Down
36 changes: 36 additions & 0 deletions django_bulk_load/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def create_temp_table(temp_table_name, source_table_name, column_names):
),
)

def copy_query(table_name: str):
    """
    Build the COPY ... FROM STDIN statement used to stream rows into a table.

    The statement reads CSV data with a tab delimiter and treats backslash-N as NULL.

    :param table_name: Name of the table to copy into; it is passed through Identifier
        rather than interpolated as raw text
    :return: The formatted psycopg2 SQL object
    """
    statement = SQL("COPY {table_name} FROM STDIN NULL '\\N' DELIMITER '\t' CSV")
    return statement.format(table_name=Identifier(table_name))

def add_returning(query: Composable, table_name: str) -> Composable:
return Composed(
Expand Down Expand Up @@ -83,6 +87,37 @@ def generate_distinct_null_condition(


def generate_insert_query(
    table_name: str,
    loading_table_name: str,
    ignore_conflicts: bool,
    insert_fields: Sequence[models.Field],
):
    """
    Build an INSERT ... SELECT statement that copies rows from the loading table into the
    destination table.

    :param table_name: Destination table to insert into
    :param loading_table_name: Table the rows are selected from
    :param ignore_conflicts: When True, " ON CONFLICT DO NOTHING" is appended to the statement
    :param insert_fields: Model fields whose columns are inserted; the same columns, qualified
        with the loading table name, form the SELECT list
    :return: The composed psycopg2 SQL object
    """
    loading_table = Identifier(loading_table_name)

    # Column names for the INSERT target list, and the matching "loading_table"."column"
    # references for the SELECT list, in the same field order.
    target_columns = [Identifier(field.column) for field in insert_fields]
    source_columns = [
        SQL(".").join((loading_table, Identifier(field.column)))
        for field in insert_fields
    ]

    insert_statement = SQL(
        "INSERT INTO {table_name} ({insert_column_list}) "
        "SELECT {select_column_list} FROM {loading_table_name}"
    ).format(
        table_name=Identifier(table_name),
        insert_column_list=SQL(", ").join(target_columns),
        select_column_list=SQL(", ").join(source_columns),
        loading_table_name=loading_table,
    )

    if not ignore_conflicts:
        return insert_statement

    return Composed([insert_statement, SQL(" ON CONFLICT DO NOTHING")])

def generate_insert_for_update_query(
table_name: str,
loading_table_name: str,
pk_fields: Sequence[models.Field],
Expand All @@ -109,6 +144,7 @@ def generate_insert_query(
)



def generate_select_latest(table_name, loading_table_name, pk_fields, order_field):
join_clause = generate_join_condition(loading_table_name, table_name, pk_fields)
return SQL(
Expand Down
2 changes: 1 addition & 1 deletion test.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#!/usr/bin/env bash
docker-compose run --rm test ./manage.py test --snapshot-update
docker-compose run --rm test ./manage.py test
File renamed without changes.
Loading

0 comments on commit 3ef7cff

Please sign in to comment.