|
| 1 | +import numpy as np |
| 2 | +from typing import Dict, List, Optional, TextIO, Tuple, Union |
| 3 | +from pathlib import Path |
| 4 | +from collections import deque |
| 5 | +from os.path import splitext |
| 6 | +from soluanalysis.james import Atom, System |
| 7 | +import h5py |
| 8 | +import soluanalysis as solu |
| 9 | +import numbers |
| 10 | + |
| 11 | + |
def save_system_to_hdf5(system: solu.james.System, hdf5_group: h5py.Group):
    """Serialize a System object into an HDF5 group.

    Writes the atom IDs, atom types, molecular IDs and positions as
    datasets; the box dimensions and lower box bounds are written only
    when they are present on the system.

    Args:
        system (solu.james.System): System object to serialize
        hdf5_group (h5py.Group): The HDF5 group to be saved to
    """
    atoms = system.atoms

    # A missing molecular ID falls back to the atom's own ID.
    per_atom_datasets = {
        "atom_ids": np.array([a.id for a in atoms], dtype=np.int32),
        "atom_types": np.array([a.type for a in atoms], dtype=np.int32),
        "mol_ids": np.array(
            [a.id if a.mol_id is None else a.mol_id for a in atoms],
            dtype=np.int32,
        ),
        "positions": np.array([a.position for a in atoms], dtype=np.float64),
    }
    # dict preserves insertion order, so datasets are created in the
    # same order as before: atom_ids, atom_types, mol_ids, positions.
    for name, values in per_atom_datasets.items():
        hdf5_group.create_dataset(name, data=values)

    # Optional box information is only serialized when set on the system.
    if system.box is not None:
        hdf5_group.create_dataset("box", data=np.array(system.box, dtype=np.float64))
    if system.boxLo is not None:
        hdf5_group.create_dataset("boxLo", data=np.array(system.boxLo, dtype=np.float64))
| 39 | + |
| 40 | + |
def read_ion_pairs_from_hdf5(
    file_path: Path,
) -> Tuple[
    Dict[int, Dict[int, List[List[int]]]],
    List[int],
    solu.james.System,
    int,
    solu.james.WriteIdentifier,
]:
    """Reads the HDF5 file and reconstructs a dictionary with the time series information about the ion pairs

    Args:
        file_path (Path): The HDF5 file to read from

    Returns:
        Tuple[Dict[int, Dict[int, List[List[int]]]], List[int], solu.james.System, int, solu.james.WriteIdentifier]:
            A tuple containing
            1) the dictionary with the ion pairs,
            2) timesteps,
            3) System object
            4) max_depth,
            5) writeIdentifier

    Raises:
        ValueError: If the ``writeIdentifier`` attribute in the file is not
            a recognized ``WriteIdentifier`` string.
    """
    time_series_dict: Dict[int, Dict[int, List[List[int]]]] = {}

    # Maps the stringified enum (as written by save_ion_pairs_to_hdf5)
    # back to the enum member.
    enum_mapping = {
        "WriteIdentifier.AtomID": solu.james.WriteIdentifier.AtomID,
        "WriteIdentifier.Index": solu.james.WriteIdentifier.Index,
    }

    with h5py.File(file_path, "r") as file:
        # Read the metadata; cast to a plain int since h5py attributes
        # come back as numpy scalars.
        max_depth = int(file.attrs["max_depth"])
        identifier_str = file.attrs["writeIdentifier"]

        # Fail loudly on an unknown identifier instead of silently
        # returning None (the previous .get() behavior).
        try:
            write_identifier = enum_mapping[identifier_str]
        except KeyError as err:
            raise ValueError(
                f"Unrecognized writeIdentifier {identifier_str!r} in {file_path}"
            ) from err

        # Read the timesteps
        timesteps = file["timesteps"][:].tolist()

        # Read the representative System object
        system = read_system_from_hdf5(file["system"])

        # Iterate over the timesteps; each timestep group holds one
        # subgroup per ion-pair length.
        for timestep in timesteps:
            timestep_group = file[str(timestep)]
            groups = {}

            for length in timestep_group.keys():
                # Read the numpy array and convert it back to a list of lists
                data_array = timestep_group[length]["ion_pairs"][:]
                groups[int(length)] = data_array.tolist()

            time_series_dict[int(timestep)] = groups

    return (
        time_series_dict,
        timesteps,
        system,
        max_depth,
        write_identifier,
    )
| 107 | + |
| 108 | + |
def read_system_from_hdf5(hdf5_group: h5py.Group) -> solu.james.System:
    """Read a System object from an HDF5 group.

    Args:
        hdf5_group (h5py.Group): HDF5 group, from which the System object will be reconstructed

    Returns:
        solu.james.System: Reconstructed System object
    """
    ids = hdf5_group["atom_ids"][:]
    types = hdf5_group["atom_types"][:]
    mols = hdf5_group["mol_ids"][:]
    coords = hdf5_group["positions"][:]

    # Rebuild the per-atom objects from the parallel arrays.
    atoms = []
    for atom_id, atom_type, mol_id, position in zip(ids, types, mols, coords):
        atoms.append(solu.james.Atom(atom_id, atom_type, mol_id, position))

    # Box datasets are optional; absent datasets map to None.
    def _optional(name):
        return hdf5_group[name][:] if name in hdf5_group else None

    return solu.james.System(atoms, _optional("box"), _optional("boxLo"))
| 139 | + |
| 140 | + |
def save_ion_pairs_to_hdf5(
    file_path: Path,
    time_series_dict: Dict[int, Dict[int, List[List[int]]]],
    system: solu.james.System,
    max_depth: int,
    write_identifier: solu.james.WriteIdentifier,
    **compression_kwargs: Union[str, int],
) -> None:
    """Save the ion pairs per time step, sorted according to length into an HDF5 file.

    Args:
        file_path (Path): File path of the HDF5 file to write to
        time_series_dict (Dict[int, Dict[int, List[List[int]]]]): Dictionary containing timesteps and ion pairs.
            The keys of the outer dictionary are timesteps, and the keys of the inner dictionary are ion pair lengths
        system (solu.james.System): Representative System object, containing indices, atom IDs, atom types, molecular IDs
        max_depth (int): Maximum length of the ion pair
        write_identifier (solu.james.WriteIdentifier): enum class which describes whether the elements correspond to
            atom IDs or indices in the System object.
        compression_kwargs(Union[str, int]): additional compression options for the create_dataset command in h5py.
            For instance, compression="gzip" and compression_opts=4
    """
    # The outer keys of time_series_dict are the timesteps; write them sorted.
    timesteps = sorted(time_series_dict.keys())

    with h5py.File(file_path, "w") as file:
        # Metadata: the enum is stored as its string form so it can be
        # mapped back on read.
        file.attrs["max_depth"] = max_depth
        file.attrs["writeIdentifier"] = str(write_identifier)

        # Timesteps live in their own dataset.
        file.create_dataset(
            "timesteps", data=np.array(timesteps, dtype=np.int32), **compression_kwargs
        )

        # The representative System object goes into its own group.
        save_system_to_hdf5(system, file.create_group("system"))

        # One group per timestep (timesteps are unique), with one
        # subgroup per ion-pair length (up to max_depth).
        for timestep in timesteps:
            timestep_group = file.create_group(str(timestep))

            for length, pair_lists in time_series_dict[timestep].items():
                length_group = timestep_group.create_group(str(length))

                # The list of lists becomes a 2-D int32 array.
                length_group.create_dataset(
                    "ion_pairs",
                    data=np.array(pair_lists, dtype=np.int32),
                    **compression_kwargs,
                )
0 commit comments