Skip to content

Commit

Permalink
feat: add ability to specify custom where clause in bulk_update_model…
Browse files Browse the repository at this point in the history
…s and bulk_upsert_models
  • Loading branch information
taylor-cedar committed Nov 15, 2023
1 parent c9ec366 commit 0dc7a65
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 23 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ bulk_upsert_models(
insert_only_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
update_where: Callable[[Sequence[Field], str, str], Composable] = None,
return_models: bool = False,
)
```
Expand All @@ -141,6 +142,7 @@ bulk_update_models(
pk_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
update_where: Callable[[Sequence[Field], str, str], Composable] = None,
return_models: bool = False,
)
```
Expand Down Expand Up @@ -175,6 +177,7 @@ bulk_select_model_dicts(
filter_field_names: Iterable[str],
select_field_names: Iterable[str],
filter_data: Iterable[Sequence],
select_for_update=False,
skip_filter_transform=False,
)
```
Expand Down
7 changes: 7 additions & 0 deletions django_bulk_load/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
bulk_upsert_models,
)

from .queries import (
generate_distinct_condition,
generate_greater_than_condition
)

# Public API of the package: the bulk load entry points plus the reusable
# WHERE-condition builders that can be returned from an `update_where` callback.
__all__ = [
    "bulk_select_model_dicts",
    "bulk_insert_models",
    "bulk_update_models",
    "bulk_upsert_models",
    "bulk_insert_changed_models",
    "bulk_load_models_with_queries",
    "generate_distinct_condition",
    "generate_greater_than_condition",
]
14 changes: 11 additions & 3 deletions django_bulk_load/bulk_load.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
from time import monotonic
from typing import Dict, Iterable, List, Optional, Sequence, Type
from typing import Dict, Iterable, List, Optional, Sequence, Type, Callable

from django.db import connections, router, transaction
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import CursorWrapper
from django.db.models import AutoField, Model
from django.db.models import AutoField, Model, Field
from psycopg2.extras import execute_values
from psycopg2.sql import Composable
from psycopg2.sql import Composable, SQL

from .django import (
django_field_to_query_value,
Expand Down Expand Up @@ -207,6 +207,7 @@ def bulk_update_models(
pk_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
update_where: Callable[[Sequence[Field], str, str], Composable] = None,
return_models: bool = False,
):
"""
Expand All @@ -219,6 +220,8 @@ def bulk_update_models(
list is changed) (i.e. update_on/last_modified)
:param update_if_null_field_names: Fields that only get updated if the new value is NULL or existing
value in the DB is NULL.
:param update_where: Function that returns a Composable that is used to filter the update query. Should not be used
with model_changed_field_names or update_if_null_field_names (can lead to unexpected behavior)
:param return_models: Query and return the models in the DB, whether updated or not.
Defaults to False, since this can significantly degrade performance
:return: None or List[Model] depending upon returns_models param. Returns all models passed in,
Expand Down Expand Up @@ -284,6 +287,7 @@ def bulk_update_models(
update_if_null_field_names, model_meta
),
pk_fields=pk_fields,
update_where=update_where,
loading_table_name=loading_table_name,
)

Expand Down Expand Up @@ -312,6 +316,7 @@ def bulk_upsert_models(
insert_only_field_names: Sequence[str] = None,
model_changed_field_names: Sequence[str] = None,
update_if_null_field_names: Sequence[str] = None,
update_where: Callable[[Sequence[Field], str, str], Composable] = None,
return_models: bool = False,
):
"""
Expand All @@ -326,6 +331,8 @@ def bulk_upsert_models(
list is changed) (i.e. update_on/last_modified)
:param update_if_null_field_names: Fields that only get updated if the new value is NULL or existing
value in the DB is NULL.
:param update_where: Function that returns a Composable that is used to filter the update query. Cannot be used
with model_changed_field_names or update_if_null_field_names
:param return_models: Query and return the models in the DB, whether updated or not.
Defaults to False, since this can significantly degrade performance
:return: None or List[Model] depending upon returns_models param. Returns all models passed in,
Expand Down Expand Up @@ -368,6 +375,7 @@ def bulk_upsert_models(
compare_fields=compare_fields,
update_fields=update_fields,
pk_fields=pk_fields,
update_where=update_where,
update_if_null_fields=get_fields_from_names(
update_if_null_field_names, model_meta
),
Expand Down
54 changes: 36 additions & 18 deletions django_bulk_load/queries.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Callable

from django.db import models
from psycopg2.sql import SQL, Composable, Composed, Identifier
Expand Down Expand Up @@ -55,7 +55,6 @@ def generate_distinct_condition(
conditions = []
for field in compare_fields:
conditions.append(
# wrap column names in double quotes to handle columns starting with a number
SQL(
"{source_table_name}.{column} IS DISTINCT FROM {destination_table_name}.{column}"
).format(
Expand All @@ -67,6 +66,20 @@ def generate_distinct_condition(
return SQL(" OR ").join(conditions)


def generate_greater_than_condition(
    source_table_name: str,
    destination_table_name: str,
    field: models.Field,
) -> Composable:
    """
    Build a SQL condition that holds when the source table's value for a
    column is strictly greater than the destination table's value for the
    same column.

    :param source_table_name: Name of the table supplying the incoming values
    :param destination_table_name: Name of the table holding the current values
    :param field: Django field whose `column` is compared in both tables
    :return: Composable of the form `source.column > destination.column`, with
    the table and column names passed as psycopg2 Identifiers
    """
    template = SQL(
        "{source_table_name}.{column} > {destination_table_name}.{column}"
    )
    # Table and column names go through Identifier rather than being
    # interpolated into the SQL text directly.
    identifiers = {
        "source_table_name": Identifier(source_table_name),
        "column": Identifier(field.column),
        "destination_table_name": Identifier(destination_table_name),
    }
    return template.format(**identifiers)


def generate_distinct_null_condition(
source_table_name: str,
destination_table_name: str,
Expand Down Expand Up @@ -218,6 +231,7 @@ def generate_update_query(
pk_fields: Sequence[models.Field],
update_fields: Sequence[models.Field],
compare_fields: Sequence[models.Field],
update_where: Callable[[Sequence[models.Field], str, str], Composable] = None,
update_if_null_fields: Sequence[models.Field] = None,
) -> Composable:
update_if_null_fields = update_if_null_fields or []
Expand Down Expand Up @@ -253,34 +267,38 @@ def generate_update_query(
destination_table_name=table_name,
fields=pk_fields,
)
distinct_from_clause = generate_distinct_condition(
source_table_name=loading_table_name,
destination_table_name=table_name,
compare_fields=compare_fields,
)

if update_if_null_fields:
distinct_null_clause = generate_distinct_null_condition(
if update_where:
where_clause = update_where(update_fields, loading_table_name, table_name)
else:
where_clause = generate_distinct_condition(
source_table_name=loading_table_name,
destination_table_name=table_name,
compare_fields=update_if_null_fields,
compare_fields=compare_fields,
)
if compare_fields:
distinct_from_clause = SQL(" OR ").join(
[distinct_from_clause, distinct_null_clause]

if update_if_null_fields:
distinct_null_clause = generate_distinct_null_condition(
source_table_name=loading_table_name,
destination_table_name=table_name,
compare_fields=update_if_null_fields,
)
else:
distinct_from_clause = distinct_null_clause
if compare_fields:
where_clause = SQL(" OR ").join(
[where_clause, distinct_null_clause]
)
else:
where_clause = distinct_null_clause

return SQL(
"UPDATE {table_name} SET {update_clause} FROM {loading_table_name} WHERE {where_clause}"
).format(
table_name=Identifier(table_name),
update_clause=update_clause,
loading_table_name=Identifier(loading_table_name),
where_clause=SQL("({distinct_from_clause}) AND ({join_clause})").format(
distinct_from_clause=distinct_from_clause, join_clause=join_clause
),
where_clause=SQL("({where_clause}) AND ({join_clause})").format(
where_clause=where_clause, join_clause=join_clause
) if where_clause else join_clause
)


Expand Down
46 changes: 45 additions & 1 deletion tests/test_bulk_update_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone

from django.test import TestCase
from django_bulk_load import bulk_update_models
from django_bulk_load import bulk_update_models, generate_greater_than_condition
from .test_project.models import (
TestComplexModel,
TestForeignKeyModel,
Expand Down Expand Up @@ -328,3 +328,47 @@ def test_complex_update_if_null(self):
self.assertIsNone(saved_model2.string_field)
self.assertEqual(saved_model3.string_field, "d")
self.assertIsNone(saved_model3.datetime_field)

def test_custom_where(self):
    """
    A custom `update_where` gates the whole row update: integer_field and
    string_field are only written when the incoming integer_field is greater
    than the value currently stored in the DB.
    """
    model1 = TestComplexModel(
        integer_field=1,
        string_field="a",
    )
    model1.save()
    model1.integer_field = 5
    model1.string_field = "b"

    model2 = TestComplexModel(integer_field=3, string_field="c")
    model2.save()
    model2.integer_field = 2
    # Use a value that differs from the stored "c" so the final assertion
    # actually fails if this row gets updated by mistake.
    model2.string_field = "d"

    def update_where(fields, source_table_name, destination_table_name):
        """
        Custom where clause where the new value must be greater than the old value
        """

        # This should only update if the new value is greater than previous value
        return generate_greater_than_condition(
            source_table_name=source_table_name,
            destination_table_name=destination_table_name,
            field=TestComplexModel._meta.get_field("integer_field"),
        )

    bulk_update_models(
        [model1, model2],
        update_field_names=["integer_field", "string_field"],
        update_where=update_where,
    )

    # First model should be updated because 5 > 1
    saved_model1 = TestComplexModel.objects.get(integer_field=5)
    self.assertEqual(saved_model1.string_field, "b")

    # Second model should not be updated because 2 < 3, so neither
    # integer_field nor string_field may change
    saved_model2 = TestComplexModel.objects.get(integer_field=3)
    self.assertEqual(saved_model2.string_field, "c")
45 changes: 44 additions & 1 deletion tests/test_bulk_upsert_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone

from django.test import TestCase
from django_bulk_load import bulk_upsert_models
from django_bulk_load import bulk_upsert_models, generate_greater_than_condition
from .test_project.models import (
TestComplexModel,
TestForeignKeyModel,
Expand Down Expand Up @@ -551,3 +551,46 @@ def test_complex_update_if_null(self):

self.assertEqual(saved_model3.string_field, "d")
self.assertIsNone(saved_model3.datetime_field)

def test_custom_update_where(self):
    """
    A custom `update_where` gates the update half of an upsert: an existing
    row is only overwritten when the incoming integer_field is greater than
    the value currently stored in the DB.
    """
    model1 = TestComplexModel(
        integer_field=1,
        string_field="a",
    )
    model1.save()
    model1.integer_field = 5
    model1.string_field = "b"

    model2 = TestComplexModel(integer_field=3, string_field="c")
    model2.save()
    model2.integer_field = 2
    # Use a value that differs from the stored "c" so the final assertion
    # actually fails if this row gets updated by mistake.
    model2.string_field = "d"

    def update_where(fields, source_table_name, destination_table_name):
        """
        Custom where clause where the new value must be greater than the old value
        """

        # This should only update if the new value is greater than previous value
        return generate_greater_than_condition(
            source_table_name=source_table_name,
            destination_table_name=destination_table_name,
            field=TestComplexModel._meta.get_field("integer_field"),
        )

    bulk_upsert_models(
        [model1, model2],
        update_where=update_where,
    )

    # First model should be updated because 5 > 1
    saved_model1 = TestComplexModel.objects.get(integer_field=5)
    self.assertEqual(saved_model1.string_field, "b")

    # Second model should not be updated because 2 < 3, so neither
    # integer_field nor string_field may change
    saved_model2 = TestComplexModel.objects.get(integer_field=3)
    self.assertEqual(saved_model2.string_field, "c")

0 comments on commit 0dc7a65

Please sign in to comment.