Skip to content

Commit

Permalink
add cell to struc (from_ase_atoms) & fix load_grid bug in mgp & repla…
Browse files Browse the repository at this point in the history
…ce spc_set with 'sort'
  • Loading branch information
YuuuXie committed Jun 16, 2020
1 parent 941afad commit 43dc506
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 11 deletions.
6 changes: 2 additions & 4 deletions flare/mgp/map2b.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@ def build_bond_struc(self, species_list):

# 2 body (2 atoms (1 bond) config)
self.spc = []
self.spc_set = []
for spc1_ind, spc1 in enumerate(species_list):
for spc2 in species_list[spc1_ind:]:
species = [spc1, spc2]
self.spc.append(species)
self.spc_set.append(set(species))
self.spc.append(sorted(species))


def get_arrays(self, atom_env):
Expand All @@ -43,7 +41,7 @@ def get_arrays(self, atom_env):

def find_map_index(self, spc):
# use set because of permutational symmetry
return self.spc_set.index(set(spc))
return self.spc.index(sorted(spc))



Expand Down
2 changes: 0 additions & 2 deletions flare/mgp/map3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def build_bond_struc(self, species_list):

# 2 body (2 atoms (1 bond) config)
self.spc = []
self.spc_set = []
N_spc = len(species_list)
for spc1_ind in range(N_spc):
spc1 = species_list[spc1_ind]
Expand All @@ -40,7 +39,6 @@ def build_bond_struc(self, species_list):
spc3 = species_list[spc3_ind]
species = [spc1, spc2, spc3]
self.spc.append(species)
self.spc_set.append(set(species))


def get_arrays(self, atom_env):
Expand Down
5 changes: 2 additions & 3 deletions flare/mgp/mapxb.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self,
self.n_sample = n_sample

self.spc = []
self.spc_set = []
self.maps = []
self.kernel_info = None
self.hyps_mask = hyps_mask
Expand Down Expand Up @@ -170,7 +169,7 @@ def as_dict(self) -> dict:
out_dict['bounds'] = [m.bounds for m in self.maps]

# rm keys since they are built in the __init__ function
key_list = ['singlexbody', 'spc_set']
key_list = ['singlexbody', 'spc']
for key in key_list:
if out_dict.get(key) is not None:
del out_dict[key]
Expand Down Expand Up @@ -477,7 +476,7 @@ def build_map(self, GP):

grid_path = f'{self.load_grid}/mgp_grids/{self.bodies}_{self.species_code}'
y_mean = np.load(f'{grid_path}_mean.npy')
y_var = np.load(f'{grid_path}_var.npy')
y_var = np.load(f'{grid_path}_var.npy', allow_pickle=True)

self.mean.set_values(y_mean)
if not self.mean_only:
Expand Down
7 changes: 5 additions & 2 deletions flare/struc.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,18 @@ def from_dict(dictionary: dict) -> 'flare.struc.Structure':
return struc

@staticmethod
def from_ase_atoms(atoms: 'ase.Atoms') -> 'flare.struc.Structure':
def from_ase_atoms(atoms: 'ase.Atoms', cell=None) -> 'flare.struc.Structure':
"""
From an ASE Atoms object, return a FLARE structure
:param atoms: ASE Atoms object
:type atoms: ASE Atoms object
:return: A FLARE structure from an ASE atoms object
"""
struc = Structure(cell=np.array(atoms.cell),

if cell is None:
cell = np.array(atoms.cell)
struc = Structure(cell=cell,
positions=atoms.positions,
species=atoms.get_chemical_symbols())
return struc
Expand Down

0 comments on commit 43dc506

Please sign in to comment.