Skip to content
This repository has been archived by the owner on Dec 26, 2022. It is now read-only.

Implement DGraph TTL cleanup job #119

Merged
merged 14 commits into from
Jun 17, 2020
Merged
9 changes: 9 additions & 0 deletions src/python/grapl-dgraph-ttl/.chalice/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"version": "2.0",
"app_name": "grapl-dgraph-ttl",
"stages": {
"dev": {
"api_gateway_stage": "api"
}
}
}
2 changes: 2 additions & 0 deletions src/python/grapl-dgraph-ttl/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.chalice/deployments/
.chalice/venv/
151 changes: 151 additions & 0 deletions src/python/grapl-dgraph-ttl/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import datetime
import json
import os

from typing import Dict, Iterable, Iterator, Optional, Tuple, Union

from chalice import Chalice

from grapl_analyzerlib.grapl_client import (
GraphClient,
LocalMasterGraphClient,
MasterGraphClient,
)

# True when running against a local (non-AWS) Grapl stack.
# NOTE(review): bool() of any non-empty string is True, so IS_LOCAL=False in
# the environment still evaluates truthy — confirm this is intended.
IS_LOCAL = bool(os.environ.get("IS_LOCAL", False))
# Entity time-to-live in seconds; <= 0 (the default) disables pruning entirely.
GRAPL_DGRAPH_TTL_S = int(os.environ.get("GRAPL_DGRAPH_TTL_S", "-1"))
# Log level name handed to the Chalice app logger below.
GRAPL_LOG_LEVEL = os.environ.get("GRAPL_LOG_LEVEL", "ERROR")
# Max entities fetched (and deleted) per Dgraph query/mutation round-trip.
GRAPL_TTL_DELETE_BATCH_SIZE = int(os.environ.get("GRAPL_TTL_DELETE_BATCH_SIZE", "100"))

app = Chalice(app_name="grapl-dgraph-ttl")
app.log.setLevel(GRAPL_LOG_LEVEL)


def query_batch(
    client: GraphClient,
    batch_size: int,
    ttl_cutoff_ms: int,
    last_uid: Optional[str] = None,
) -> Iterable[Dict[str, Union[Dict, str]]]:
    """Fetch one page of entities whose last_index_time is <= ttl_cutoff_ms.

    Pagination is uid-based: pass the last uid of the previous page as
    last_uid to continue from where that page left off.
    """
    pagination = f"first: {batch_size}"
    if last_uid is not None:
        pagination += f", after: {last_uid}"

    query = f"""
    {{
        q(func: le(last_index_time, {ttl_cutoff_ms}), {pagination}) {{
            uid,
            expand(_all_) {{ uid }}
        }}
    }}
    """

    txn = client.txn()
    try:
        app.log.debug(f"retrieving batch: {query}")
        batch = txn.query(query)
        app.log.debug(f"retrieved batch: {batch}")
        return json.loads(batch.json)["q"]
    finally:
        # Read-only transaction; discard unconditionally.
        txn.discard()


def calculate_ttl_cutoff_ms(now: datetime.datetime, ttl_s: int) -> int:
    """Return the epoch-millisecond instant ttl_s seconds before *now*.

    Entities last indexed at or before this cutoff are considered expired.

    A naive *now* is interpreted as UTC. This matters because the caller
    passes ``datetime.datetime.utcnow()``, which is naive: calling
    ``.timestamp()`` on a naive datetime assumes the *local* timezone,
    which would skew the cutoff by the host's UTC offset.
    """
    if now.tzinfo is None:
        now = now.replace(tzinfo=datetime.timezone.utc)
    cutoff = now - datetime.timedelta(seconds=ttl_s)
    return int(cutoff.timestamp() * 1000)


def expired_entities(
    client: GraphClient, now: datetime.datetime, ttl_s: int, batch_size: int
) -> Iterator[Iterable[Dict[str, Union[Dict, str]]]]:
    """Yield pages of entities last indexed more than ttl_s seconds before *now*.

    Each yielded value is one batch of at most batch_size entity dicts.
    """
    ttl_cutoff_ms = calculate_ttl_cutoff_ms(now, ttl_s)

    app.log.info(f"Pruning entities last indexed before {ttl_cutoff_ms}")

    cursor = None
    while True:
        batch = query_batch(client, batch_size, ttl_cutoff_ms, cursor)

        if batch:
            # Remember where this page ended so the next query resumes there.
            cursor = batch[-1]["uid"]
            yield batch

        if len(batch) < batch_size:
            # A short (or empty) page means the results are exhausted.
            return


def nodes(entities: Iterable[Dict[str, Union[Dict, str]]]) -> Iterator[str]:
    """Yield the uid of every entity, in order."""
    return (entity["uid"] for entity in entities)


def edges(
    entities: Iterable[Dict[str, Union[Dict, str]]]
) -> Iterator[Tuple[str, str, str]]:
    """Yield (src_uid, predicate, dest_uid) for every edge on each entity.

    An edge is recognized as a list-valued predicate whose elements are
    dicts containing exactly one key, "uid".
    """
    for entity in entities:
        src_uid = entity["uid"]
        for predicate, value in entity.items():
            if not isinstance(value, list):
                continue  # scalar predicate, not an edge list
            for target in value:
                if isinstance(target, dict) and set(target.keys()) == {"uid"}:
                    yield (src_uid, predicate, target["uid"])


def delete_nodes(client: GraphClient, nodes: Iterator[str]) -> int:
    """Delete every node uid in *nodes* in a single committed mutation.

    Returns the number of nodes deleted.
    """
    deletions = [{"uid": uid} for uid in nodes]

    txn = client.txn()
    try:
        mutation = txn.create_mutation(del_obj=deletions)
        app.log.debug(f"deleting nodes: {mutation}")
        txn.mutate(mutation=mutation, commit_now=True)
        app.log.debug(f"deleted nodes: {json.dumps(deletions)}")
        return len(deletions)
    finally:
        # No-op after a successful commit; rolls back on error.
        txn.discard()


def delete_edges(client: GraphClient, edges: Iterator[Tuple[str, str, str]]) -> int:
    """Delete every (src_uid, predicate, dest_uid) edge in a single committed mutation.

    Returns the number of edges deleted.
    """
    deletions = [create_edge_obj(src, pred, dest) for src, pred, dest in edges]

    txn = client.txn()
    try:
        mutation = txn.create_mutation(del_obj=deletions)
        app.log.debug(f"deleting edges: {mutation}")
        txn.mutate(mutation=mutation, commit_now=True)
        app.log.debug(f"deleted edges: {json.dumps(deletions)}")
        return len(deletions)
    finally:
        # No-op after a successful commit; rolls back on error.
        txn.discard()


def create_edge_obj(
    src_uid: str, predicate: str, dest_uid: str
) -> Dict[str, Union[Dict, str]]:
    """Build the Dgraph delete-object for a single edge.

    A predicate beginning with "~" is a reverse edge: the underlying edge is
    stored on the destination node under the forward predicate name, so the
    delete object must be anchored on dest_uid instead of src_uid.
    """
    if predicate.startswith("~"):  # reverse edge
        # Strip exactly one leading "~". The previous lstrip("~") removed
        # *all* leading tildes, which would mangle a predicate whose forward
        # name itself starts with "~".
        return {"uid": dest_uid, predicate[1:]: {"uid": src_uid}}
    else:  # forward edge
        return {"uid": src_uid, predicate: {"uid": dest_uid}}


@app.lambda_function(name="prune_expired_subgraphs")
def prune_expired_subgraphs() -> None:
    """Lambda entry point: delete all entities (and their edges) older than the TTL."""
    if GRAPL_DGRAPH_TTL_S <= 0:
        # TTL disabled (the default); nothing to prune.
        return

    client = LocalMasterGraphClient() if IS_LOCAL else MasterGraphClient()

    node_count = 0
    edge_count = 0

    for entities in expired_entities(
        client,
        now=datetime.datetime.utcnow(),
        ttl_s=GRAPL_DGRAPH_TTL_S,
        batch_size=GRAPL_TTL_DELETE_BATCH_SIZE,
    ):
        # Edges are removed before their endpoint nodes, one batch at a time.
        edge_count += delete_edges(client, edges(entities))
        node_count += delete_nodes(client, nodes(entities))

    app.log.info(f"Pruned {node_count} nodes and {edge_count} edges")
3 changes: 3 additions & 0 deletions src/python/grapl-dgraph-ttl/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pydgraph==2.0.2
grapl_analyzerlib==0.2.*
chalice
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ def add_reverse_edge_type(

def query_dgraph_type(client: GraphClient, type_name: str) -> List[str]:
query = f"""
schema(type: {type_name}) {{
}}
schema(type: {type_name}) {{ }}
"""
LOGGER.debug(f"query: {query}")
txn = client.txn(read_only=True)
Expand All @@ -206,11 +205,12 @@ def query_dgraph_type(client: GraphClient, type_name: str) -> List[str]:

pred_names = []

for field in res["types"][0]["fields"]:
pred_name = (
f"<{field['name']}>" if field["name"].startswith("~") else field["name"]
)
pred_names.append(pred_name)
if "types" in res:
for field in res["types"][0]["fields"]:
pred_name = (
f"<{field['name']}>" if field["name"].startswith("~") else field["name"]
)
pred_names.append(pred_name)

return pred_names

Expand Down
Loading