Skip to content

Commit

Permalink
Add generation of class_refatt_map for pydantic models
Browse files Browse the repository at this point in the history
  • Loading branch information
theferrit32 committed Jul 13, 2023
1 parent e34e7cb commit c24c494
Showing 1 changed file with 108 additions and 20 deletions.
128 changes: 108 additions & 20 deletions src/ga4gh/vrs/_internal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,131 @@
"""


from pydantic import BaseModel, Extra, Field, constr
from typing import Any, Dict, List, Optional, Union, Literal
from enum import Enum
import inspect
import logging
import os

import sys
import typing
import pkg_resources

from ga4gh.core import build_models, build_class_referable_attribute_map

_logger = logging.getLogger(__name__)

# specify VRSATILE_SCHEMA_DIR to use a schema other than the one embedded
# in vrs-python
schema_dir = os.environ.get("VRSATILE_SCHEMA_DIR", pkg_resources.resource_filename(__name__, "data/schemas/vrsatile"))
schema_path = schema_dir + "/merged.json"

models = None
class_refatt_map = None


def _load_models():
"""load/reload models from `schema_path`
This function facilitates reloading changes to the schema during
development.

def flatten(vals):
"""
Flattens vals recursively, lazily using yield
"""
def is_coll(thing):
"""
Return True if the thing looks like a collection.
This is not exhaustive, do not use in general.
"""
# return hasattr(thing, '__iter__') and not isinstance(thing, str) and not inspect.isclass(thing)
return type(thing) in [list, set]
if is_coll(vals):
for x in vals:
for fx in flatten(x):
yield fx
else:
yield vals


def flatten_type(t):
"""
Flattens a complex type into a list of constituent types
"""
if ('__origin__' in t.__dict__
and t.__origin__ != typing.Literal
and (t.__origin__ == typing.Union
or issubclass(t.__origin__, typing.List))):
return list(flatten([flatten_type(sub_t) for sub_t in t.__args__]))
else:
return [t]

global class_refatt_map, models
models = build_models(schema_path, standardize_names=False)
class_refatt_map = build_class_referable_attribute_map(models)
return models

def overlaps(a: list, b: list):
"""
Returns true if there are any elements in common between a and b
"""
intersection = set(a).intersection(set(b))
return len(intersection) > 0

_load_models()

def pydantic_class_refatt_map():
"""
Builds a map of class names to their field names that are referable types.
As in, types with an identifier that can be referred to elsewhere,
collapsed to that identifier and dereferenced.
"""
# Things defined here that are classes that inherit from BaseModel
this_module = sys.modules[__name__]
global_map = globals()
model_classes = list(filter(
lambda c: (inspect.isclass(c)
and issubclass(c, BaseModel)
and inspect.getmodule(c) == this_module),
[kv[1] for kv in global_map.items()]
))
# Types directly reffable
reffable_classes = list(filter(
lambda c: ('id' in c.__fields__
and hasattr(c, 'Ga4ghDigest')),
model_classes
))
# Types reffable because they are a union of reffable types
union_reffable_classes = list(filter(
lambda c: ('__root__' in c.__fields__
and overlaps(reffable_classes,
flatten_type(c.__fields__['__root__'].type_))),
model_classes
))
reffable_fields = {}
# Find any field whose type is a subclass of a reffable type,
# or which is a typing.List that includes a reffable type.
# Interestingly, ModelField.type_ is the member type for List types, not type List itself.
for model_class in model_classes:
# if model_class in union_reffable_classes:
# continue # These don't have fields other than __root__
fields = model_class.__fields__
class_reffable_fields = []
for fieldname, field in fields.items():
if fieldname == '__root__':
continue
field_type = field.type_ # a typing or normal annotation like str
# options
# raw class type annotation (int, str, dict, etc)
# typing.Literal, typing.Union, typing.Optional
# thankfully, issubclass actually handles these
# print(f'{model_class=} {fieldname=} {field_type=}')
if any([rc in flatten_type(field_type)
for rc in (reffable_classes
+ union_reffable_classes)]):
class_reffable_fields.append(fieldname)
if len(class_reffable_fields) > 0:
reffable_fields[model_class.__name__] = class_reffable_fields
return reffable_fields


def get_models():
this_module = sys.modules[__name__]
global_map = globals()
model_classes = list(filter(
lambda c: (inspect.isclass(c)
and issubclass(c, BaseModel)
and inspect.getmodule(c) == this_module),
[kv[1] for kv in global_map.items()]
))
return model_classes


class_refatt_map = pydantic_class_refatt_map()
models = get_models()


class Extension(BaseModel):
Expand Down

0 comments on commit c24c494

Please sign in to comment.