Skip to content

Commit

Permalink
DataManager: no longer export all data labels
Browse files Browse the repository at this point in the history
  • Loading branch information
bkpoon committed Oct 28, 2022
1 parent fde97ee commit 0caf7ea
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 47 deletions.
163 changes: 123 additions & 40 deletions iotbx/data_manager/miller_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class MillerArrayDataManager(DataManagerBase):
_labels_str = '_%s_labels'
_arrays_str = '_%s_arrays'
_user_selected_labels_str = '_user_selected_%s_labels'
_all_labels_str = '_all_%s_labels'

# template error message
_unrecognized_type_error_str = 'Unrecognized %s type, "%s," possible choices are %s.'
Expand Down Expand Up @@ -109,15 +110,55 @@ def get_miller_array_array_type(self, filename=None, label=None):
return self._get_miller_array_array_type(MillerArrayDataManager.datatype,
filename, label)

def get_miller_array_labels(self, filename=None):
def get_miller_array_all_labels(self, filename=None):
  '''
  Returns the complete list of array labels recorded for a file

  Parameters
  ----------
  filename : str
    The filename for the labels; passed unchanged to _get_all_array_labels

  Returns
  -------
  labels : list
    The result of _get_all_array_labels for MillerArrayDataManager.datatype
  '''
  datatype = MillerArrayDataManager.datatype
  all_labels = self._get_all_array_labels(datatype, filename)
  return all_labels

def get_miller_array_labels(self, filename=None, return_all=True):
'''
Returns a list of array labels
Returns a list of array labels. If the list has no items, the
return_all parameter controls whether or not to return all the labels
Parameters
----------
filename : str
The filename for the labels, if None, the default filename is used
return_all : bool, optional
Keeps original behavior of returning all the labels
Returns
-------
labels : list
The list of labels
'''
return self._get_array_labels(MillerArrayDataManager.datatype, filename)
labels = self._get_array_labels(MillerArrayDataManager.datatype, filename)
if len(labels) == 0 and return_all:
labels = self._get_all_array_labels(MillerArrayDataManager.datatype, filename)
return labels

def get_miller_array_user_selected_labels(self, filename=None):
  '''
  Returns the labels that were explicitly selected for a file

  Parameters
  ----------
  filename : str
    The filename for the labels; passed unchanged to
    _get_user_selected_array_labels

  Returns
  -------
  labels : list
    The result of _get_user_selected_array_labels for
    MillerArrayDataManager.datatype
  '''
  datatype = MillerArrayDataManager.datatype
  selected_labels = self._get_user_selected_array_labels(datatype, filename)
  return selected_labels

def set_miller_array_user_selected_labels(self, filename=None, labels=None):
  '''
  Set the array of selected labels.

  Parameters
  ----------
  filename : str
    The filename for the labels, if None, the default filename is used
  labels : list
    The list of labels to set, if None, all labels in the file are selected

  Raises
  ------
  Sorry
    If a requested label is not among the labels returned by
    get_miller_array_all_labels for the file
  '''
  if filename is None:
    filename = self.get_default_miller_array_name()
  all_labels = self.get_miller_array_all_labels(filename)
  if labels is None:
    labels = all_labels
  for label in labels:
    if label not in all_labels:
      raise Sorry('{} does not exist in {}.'.format(label, filename))
  # Store a copy rather than the list passed in (or the list returned by
  # get_miller_array_all_labels), so that later in-place edits of the stored
  # selection cannot change the caller's list or the record of all labels.
  storage = getattr(self, self._user_selected_labels_str % MillerArrayDataManager.datatype)
  storage[filename] = list(labels)

def get_miller_array_types(self, filename=None):
'''
Returns a dict of array types, keyed by label
Expand Down Expand Up @@ -154,6 +195,38 @@ def process_miller_array_file(self, filename):
self._process_file(MillerArrayDataManager.datatype, filename)
self._filter_miller_array_child_datatypes(filename)

def _detect_miller_array_array_type(self, array):
  '''
  Convenience function for getting the array_type of a Miller array

  Parameters
  ----------
  array : cctbx.miller.array
    The Miller array object

  Returns
  -------
  array_type : str
    The name paired with the first is_* check that passes, or the default
    array_type when none of them pass
  '''
  # Checks run in this order and stop at the first one that passes.
  ordered_checks = (
    ('is_xray_amplitude_array', 'amplitude'),
    ('is_xray_intensity_array', 'intensity'),
    ('is_complex_array', 'complex'),
    ('is_hendrickson_lattman_array', 'hendrickson_lattman'),
    ('is_integer_array', 'integer'),
    ('is_bool_array', 'bool'),
    ('is_nonsense', 'nonsense'),
  )
  detected = self._default_miller_array_array_type
  for check_name, type_name in ordered_checks:
    if getattr(array, check_name)():
      detected = type_name
      break
  return detected

def filter_miller_array_arrays(self, filename):
'''
Populate data structures with all arrays
Expand All @@ -173,22 +246,9 @@ def filter_miller_array_arrays(self, filename):
labels.append(label)
self._miller_array_arrays[filename][label] = array
self._miller_array_types[filename][label] = self._default_miller_array_type
self._miller_array_array_types[filename][label] = self._default_miller_array_array_type
if array.is_xray_amplitude_array():
self._miller_array_array_types[filename][label] = 'amplitude'
elif array.is_xray_intensity_array():
self._miller_array_array_types[filename][label] = 'intensity'
elif array.is_complex_array():
self._miller_array_array_types[filename][label] = 'complex'
elif array.is_hendrickson_lattman_array():
self._miller_array_array_types[filename][label] = 'hendrickson_lattman'
elif array.is_integer_array():
self._miller_array_array_types[filename][label] = 'integer'
elif array.is_bool_array():
self._miller_array_array_types[filename][label] = 'bool'
elif array.is_nonsense():
self._miller_array_array_types[filename][label] = 'nonsense'
self._miller_array_labels[filename] = labels
self._miller_array_array_types[filename][label] = self._detect_miller_array_array_type(array)
self._all_miller_array_labels[filename] = labels
self._miller_array_labels[filename] = []
self._user_selected_miller_array_labels[filename] = []

def write_miller_array_file(self, mtz_object, filename=Auto, overwrite=Auto):
Expand Down Expand Up @@ -265,14 +325,13 @@ def get_reflection_file_server(self, filenames=None, labels=None,
if len(filenames) > len(labels):
labels += [None]*(len(filenames) - len(labels))
assert len(filenames) == len(labels)

# check for user selected labels
selected_labels = deepcopy(labels)
for i, filename in enumerate(filenames):
current_selected_labels = self.get_miller_array_user_selected_labels(filename)
current_all_labels = labels[i]
if labels[i] is None:
current_all_labels = self.get_miller_array_labels(filename)
current_all_labels = self.get_miller_array_all_labels(filename)
if len(current_selected_labels) > 0:
selected_types = set()
# add selected labels
Expand All @@ -287,7 +346,6 @@ def get_reflection_file_server(self, filenames=None, labels=None,
current_selected_labels.append(label)
selected_labels[i] = current_selected_labels
labels = selected_labels

# force crystal symmetry if a crystal symmetry is provided
if crystal_symmetry is not None and force_symmetry is None:
force_symmetry = True
Expand Down Expand Up @@ -318,7 +376,7 @@ def get_reflection_file_server(self, filenames=None, labels=None,
force_symmetry=force_symmetry,
merge_equivalents=merge_equivalents)
if file_labels is None:
file_labels = self.get_miller_array_labels(filename)
file_labels = self.get_miller_array_all_labels(filename)
for miller_array in file_arrays:
label_name = miller_array.info().label_string()
# check array label
Expand All @@ -343,6 +401,7 @@ def _add_miller_array_phil_str(self, datatype):
# set up storage
# self._miller_array_types = {} # [filename] = type dict
# self._miller_array_array_types = {} # [filename] = type dict
# self._miller_array_all_labels = {} # [filename] = label list
# self._miller_array_labels = {} # [filename] = label list
# self._miller_array_arrays = {} # [filename] = array dict
# self._user_selected_miller_array_labels = {} # [filename] = array dict
Expand All @@ -353,13 +412,17 @@ def _add_miller_array_phil_str(self, datatype):
setattr(self, self._array_type_str % datatype, {})
setattr(self, self._default_array_type_str % datatype, 'unknown')
setattr(self, self._possible_array_types_str % datatype,
['amplitude', 'bool', 'complex', 'hendrickson_lattman', 'integer',
'intensity', 'nonsense', 'unknown'])
['unknown', 'amplitude', 'bool', 'complex', 'hendrickson_lattman', 'integer',
'intensity', 'nonsense'])
setattr(self, self._all_labels_str % datatype, {})
setattr(self, self._labels_str % datatype, {})
setattr(self, self._arrays_str % datatype, {})
setattr(self, self._user_selected_labels_str % datatype, {})

# custom PHIL section
# all the array labels in a file are not stored
# user_selected_labels stores the actual user selection
# labels stores the full selected label and metadata
custom_phil_str = '''
%s
.multiple = True
Expand All @@ -386,8 +449,8 @@ def _add_miller_array_phil_str(self, datatype):
.style = hidden
}
''' % (datatype,
' '.join(getattr(self, self._possible_types_str % datatype)),
' '.join(getattr(self, self._possible_array_types_str % datatype)))
' '.join(getattr(self, self._possible_types_str % datatype)),
' '.join(getattr(self, self._possible_array_types_str % datatype)))

# add fmodel PHIL
if self.supports('model'):
Expand All @@ -414,11 +477,12 @@ def _export_miller_array_phil_extract(self, datatype):
item_extract = getattr(self, '_custom_%s_phil' % datatype).extract()
item_extract = deepcopy(getattr(item_extract, '%s' % datatype)[0])
item_extract.file = filename
labels = self._get_array_labels(datatype, filename=filename)
types = self._get_array_types(datatype, filename=filename)
array_types = self._get_array_array_types(datatype, filename=filename)
user_selected_labels = self._get_user_selected_array_labels(datatype, filename=filename)
if len(labels) != len(types.keys()):
arrays = self.get_miller_arrays(filename=filename)
labels = [self._match_label(label, arrays) for label in user_selected_labels]
if len(labels) > len(types.keys()):
raise Sorry('Some labels do not have types.\n{}\n{}'.format(labels, list(types.keys())))
labels_extract = []
for label in labels:
Expand All @@ -435,7 +499,6 @@ def _export_miller_array_phil_extract(self, datatype):
def _match_label(self, label, miller_arrays):
'''
Convenience function for matching partially specified labels
A Sorry can be raised if no matching label is found.
Parameters
----------
Expand Down Expand Up @@ -466,13 +529,18 @@ def _load_miller_array_phil_extract(self, datatype, phil_extract):
if not hasattr(item_extract, 'file'):
raise Sorry('This PHIL is not properly defined for the %s datatype.\n There should be a parameter for the filename ("file").\n')

# process file
# process file, or last loaded file
if item_extract.file is None:
filenames = self.get_miller_array_names()
if len(filenames) == 0:
raise Sorry('No reflection files are available to continue processing PHIL.')
item_extract.file = filenames[-1]
getattr(self, 'process_%s_file' % datatype)(item_extract.file)

# check labels (if available)
if len(item_extract.labels) > 0 or len(item_extract.user_selected_labels) > 0:
# all labels in file
file_labels = getattr(self, self._labels_str % datatype)[item_extract.file]
file_labels = getattr(self, self._all_labels_str % datatype)[item_extract.file]

# labels from PHIL
phil_labels = []
Expand All @@ -484,8 +552,8 @@ def _load_miller_array_phil_extract(self, datatype, phil_extract):
label_name = label
if hasattr(label_name, 'name'):
label_name = label_name.name
phil_user_selected_labels.append(label_name)
if label_name not in file_labels:
phil_user_selected_labels.append(label_name)

# try matching
label_match = self._match_label(label_name, self.get_miller_arrays(filename=item_extract.file))
Expand All @@ -506,7 +574,13 @@ def _load_miller_array_phil_extract(self, datatype, phil_extract):
raise Sorry(self._unrecognized_type_error_str %
(datatype, label.array_type, ', '.join(
getattr(self, self._possible_array_types_str % datatype))))
phil_array_types[label_name] = label.array_type
# if a non-default type is set, assume it is manually set
# otherwise, automatically detect array_type
array_type = label.array_type
if array_type == self.get_default_miller_array_array_type():
array = self.get_miller_arrays(labels=[label_name], filename=item_extract.file)[0]
array_type = self._detect_miller_array_array_type(array)
phil_array_types[label_name] = array_type

# update storage
labels_storage = getattr(self, self._labels_str % datatype)[item_extract.file]
Expand Down Expand Up @@ -549,7 +623,7 @@ def _set_miller_array_type(self, datatype, filename=None, label=None,
if filename is None:
filename = self._get_default_name(datatype)
if label is None:
label = self._get_array_labels(datatype, filename)[0]
label = self._get_all_array_labels(datatype, filename)[0]
if array_type is None:
array_type = getattr(self, self._default_type_str % datatype)
elif array_type not in getattr(self, self._possible_types_str % datatype):
Expand All @@ -563,7 +637,7 @@ def _get_miller_array_type(self, datatype, filename=None, label=None):
if filename is None:
filename = self._get_default_name(datatype)
if label is None:
label = self._get_array_labels(datatype, filename)[0]
label = self._get_all_array_labels(datatype, filename)[0]
types = self._get_array_types(datatype, filename)
return types.get(label, getattr(self, self._default_type_str % datatype))

Expand All @@ -572,7 +646,7 @@ def _set_miller_array_array_type(self, datatype, filename=None, label=None,
if filename is None:
filename = self._get_default_name(datatype)
if label is None:
label = self._get_array_labels(datatype, filename)[0]
label = self._get_all_array_labels(datatype, filename)[0]
if array_type is None:
array_type = getattr(self, self._default_array_type_str % datatype)
elif array_type not in getattr(self, self._possible_array_types_str % datatype):
Expand All @@ -586,7 +660,7 @@ def _get_miller_array_array_type(self, datatype, filename=None, label=None):
if filename is None:
filename = self._get_default_name(datatype)
if label is None:
label = self._get_array_labels(datatype, filename)[0]
label = self._get_all_array_labels(datatype, filename)[0]
types = self._get_array_array_types(datatype, filename)
return types.get(label, getattr(self, self._default_array_type_str % datatype))

Expand Down Expand Up @@ -616,6 +690,13 @@ def _check_miller_array_storage_dict(self, datatype, storage_dict, filename):
if filename not in storage_dict.keys():
raise Sorry('There are no known %s arrays in %s' % (datatype, filename))

def _get_all_array_labels(self, datatype, filename=None):
  '''
  Look up the stored list of all array labels for a file

  Parameters
  ----------
  datatype : str
    The datatype name used to build the name of the storage attribute
  filename : str
    The filename for the labels; first passed through
    _check_miller_array_default_filename (presumably substituting the
    default filename when None -- confirm against that helper)

  Returns
  -------
  labels : list
    The stored list itself, not a copy
  '''
  filename = self._check_miller_array_default_filename(datatype, filename)
  all_labels_by_file = getattr(self, self._all_labels_str % datatype)
  # validates that the file has an entry before it is indexed below
  self._check_miller_array_storage_dict(datatype, all_labels_by_file, filename)
  return all_labels_by_file[filename]

def _get_array_labels(self, datatype, filename=None):
filename = self._check_miller_array_default_filename(datatype, filename)
storage_dict = getattr(self, self._labels_str % datatype)
Expand Down Expand Up @@ -649,7 +730,7 @@ def _get_arrays(self, datatype, filename=None, labels=None):
if filename not in getattr(self, 'get_%s_names' % datatype)():
self.process_miller_array_file(filename)
if labels is None:
labels = self._get_array_labels(datatype, filename=filename)
labels = self._get_all_array_labels(datatype, filename=filename)
else:
if not isinstance(labels, list):
raise Sorry('The labels argument should be a list of labels')
Expand Down Expand Up @@ -680,6 +761,7 @@ def _child_filter_arrays(self, datatype, filename, known_labels):
miller_arrays = data.as_miller_arrays(merge_equivalents=merge_equivalents)
labels = []
types = {}
# array_types = {}
datatype_dict = getattr(self, self._arrays_str % datatype)
for array in miller_arrays:
label = set(array.info().labels)
Expand All @@ -691,12 +773,13 @@ def _child_filter_arrays(self, datatype, filename, known_labels):
datatype_dict[filename] = {}
datatype_dict[filename][label] = array
types[label] = getattr(self, self._default_type_str % datatype)
# array_types[label] = getattr(self, self._default_array_type_str % datatype)

# if arrays exist, start tracking
if len(labels) > 1:
getattr(self, self._labels_str % datatype)[filename] = labels
getattr(self, self._type_str % datatype)[filename] = types
getattr(self, self._array_type_str % datatype)[filename] = {}
# getattr(self, self._array_type_str % datatype)[filename] = array_types
getattr(self, self._user_selected_labels_str % datatype)[filename] = []
self._add(datatype, filename, data)

Expand Down
Loading

0 comments on commit 0caf7ea

Please sign in to comment.