Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 169 additions & 20 deletions src/boutdata/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def collect(
strict=False,
tind_auto=False,
datafile_cache=None,
zguards=False,
):
"""Collect a variable from a set of BOUT++ outputs.

Expand Down Expand Up @@ -206,6 +207,7 @@ def getDataFile(i):
prefix,
strict,
datafile_cache,
zguards,
)

nfiles = len(file_list)
Expand All @@ -216,6 +218,7 @@ def getDataFile(i):
f,
xguards=xguards,
yguards=yguards,
zguards=zguards,
tind=tind,
xind=xind,
yind=yind,
Expand Down Expand Up @@ -259,14 +262,17 @@ def getDataFile(i):

if info:
print(
"mxsub = {} mysub = {} mz = {}\n".format(
grid_info["mxsub"], grid_info["mysub"], grid_info["nz"]
"mxsub = {} mysub = {} mzsub = {}\n".format(
grid_info["mxsub"], grid_info["mysub"], grid_info["mzsub"]
)
)

print(
"nxpe = {}, nype = {}, npes = {}\n".format(
grid_info["nxpe"], grid_info["nype"], grid_info["npes"]
"nxpe = {}, nype = {}, nzpe = {} npes = {}\n".format(
grid_info["nxpe"],
grid_info["nype"],
grid_info["nzpe"],
grid_info["npes"],
)
)
if grid_info["npes"] < nfiles:
Expand Down Expand Up @@ -316,6 +322,7 @@ def getDataFile(i):
zind=zind,
xguards=xguards,
yguards=(yguards is not False),
zguards=zguards,
info=info,
)
if is_fieldperp:
Expand All @@ -327,7 +334,7 @@ def getDataFile(i):
varname,
yindex_global,
temp_yindex,
i // grid_info["nxpe"],
grid_info["nype"],
fieldperp_yproc,
var_attributes,
temp_f_attributes,
Expand All @@ -337,11 +344,12 @@ def getDataFile(i):
f.close()

# if a step was requested in x or y, need to apply it here
data = _apply_step(data, dimensions, xind.step, yind.step)
data = _apply_step(data, dimensions, xind.step, yind.step, zind.step)

# Finished looping over all files
if info:
sys.stdout.write("\n")

return BoutArray(data, attributes=var_attributes)


Expand All @@ -359,6 +367,7 @@ def _collect_from_single_file(
prefix,
strict,
datafile_cache,
zguards,
):
"""
Collect data from a single file
Expand Down Expand Up @@ -394,6 +403,11 @@ def _collect_from_single_file(
except KeyError:
myg = 0
print(f"MYG not found, setting to {myg}")
try:
mzg = f["MZG"]
except KeyError:
mzg = 0
print(f"MZG not found, setting to {mzg}")

if xguards:
nx = f["nx"]
Expand All @@ -407,7 +421,11 @@ def _collect_from_single_file(
ny = ny + 2 * myg
else:
ny = f["ny"]
nz = f["MZ"]

if zguards:
nz = f["nz"] + 2 * mzg
else:
nz = f["nz"]
t_array = f.read("t_array")
if t_array is None:
nt = 1
Expand All @@ -429,6 +447,8 @@ def _collect_from_single_file(
xind = slice(xind.start + mxg, xind.stop + mxg, xind.step)
if not yguards:
yind = slice(yind.start + myg, yind.stop + myg, yind.step)
if not zguards:
zind = slice(zind.start + mzg, zind.stop + mzg, zind.step)

dim_ranges = {"t": tind, "x": xind, "y": yind, "z": zind}
ranges = [dim_ranges.get(dim, None) for dim in dimensions]
Expand Down Expand Up @@ -469,7 +489,7 @@ def _read_scalar(f, varname, dimensions, var_attributes, tind):
return BoutArray(data, attributes=var_attributes)


def _apply_step(data, dimensions, xstep, ystep):
def _apply_step(data, dimensions, xstep, ystep, zstep):
"""
Apply steps of xind and yind slices to an array

Expand All @@ -492,6 +512,9 @@ def _apply_step(data, dimensions, xstep, ystep):
if "y" in dimensions:
slices[dimensions.index("y")] = slice(None, None, ystep)

if "z" in dimensions:
slices[dimensions.index("z")] = slice(None, None, zstep)

return data[tuple(slices)]


Expand All @@ -510,6 +533,7 @@ def _collect_from_one_proc(
zind,
xguards,
yguards,
zguards,
info,
parallel_read=False,
):
Expand Down Expand Up @@ -594,15 +618,20 @@ def _collect_from_one_proc(

nxpe = grid_info["nxpe"]
nype = grid_info["nype"]
nzpe = grid_info["nzpe"]
mxsub = grid_info["mxsub"]
mysub = grid_info["mysub"]
mzsub = grid_info["mzsub"]
mxg = grid_info["mxg"]
myg = grid_info["myg"]
mzg = grid_info["mzg"]
yproc_upper_target = grid_info["yproc_upper_target"]

# Get X and Y processor indices
pe_yind = i // nxpe
pe_xind = i % nxpe
# Get processor indices. `grid_info` only has global data, whereas these are
# specific to each file
pe_xind = datafile.read("PE_XIND") or i % nxpe
pe_yind = datafile.read("PE_YIND") or (i // nxpe) % nype
pe_zind = datafile.read("PE_ZIND") or i // (nxpe * nype)

inrange = True

Expand All @@ -624,18 +653,38 @@ def _collect_from_one_proc(
yguards, yind, pe_yind, nype, yproc_upper_target, mysub, myg, inrange
)

is_field2d = dimensions == ("t", "x", "y") or dimensions == ("x", "y")
if is_field2d:
# Field2Ds do not have a z-dimension, so cannot be sliced in z and should
# always be read regardless of the value of zind (so we should not change
# inrange by checking the z-range).
# zstart, zstop, zgstart and zgstop are set only to avoid errors in 'info'
# messages.
zstart = 0
zstop = 1
zgstart = 0
zgstop = 1
else:
zstart, zstop, zgstart, zgstop, inrange = _get_z_range(
zguards, zind, pe_zind, nzpe, mzsub, mzg, inrange
)

if not inrange:
return None, None # Don't need this file

local_dim_slices = {
"t": tind,
"x": slice(xstart, xstop),
"y": slice(ystart, ystop),
"z": zind,
"z": slice(zstart, zstop),
}
local_slices = tuple(local_dim_slices.get(dim, None) for dim in dimensions)

global_dim_slices = {"x": slice(xgstart, xgstop), "y": slice(ygstart, ygstop)}
global_dim_slices = {
"x": slice(xgstart, xgstop),
"y": slice(ygstart, ygstop),
"z": slice(zgstart, zgstop),
}
if parallel_read:
# When reading in parallel, we are always reading into a 4-dimensional shared
# array. Should not reach this function unless we only have dimensions in
Expand All @@ -652,7 +701,8 @@ def _collect_from_one_proc(

if info:
print(
f"\rReading from {i}: [{xstart}-{xstop - 1}][{ystart}-{ystop - 1}] -> [{xgstart}-{xgstop - 1}][{ygstart}-{ygstop - 1}]\n"
f"\rReading from {i}: [{xstart}-{xstop - 1}][{ystart}-{ystop - 1}][{zstart}-{zstop - 1}] "
f"-> [{xgstart}-{xgstop - 1}][{ygstart}-{ygstop - 1}][{zgstart}-{zgstop - 1}]\n"
)

if is_fieldperp:
Expand Down Expand Up @@ -684,8 +734,7 @@ def _fieldperp_from_this(nype, pe_yind, mysub, myg, temp_yindex):

def _check_local_range_lower(start, stop, lower_index, inrange):
"""
Utility function for _get_x_range and _get_y_range. Checks inner or lower edge of
local ranges.
Utility function for `_get_{x,y,z}_range`. Checks inner or lower edge of local ranges.

Parameters
----------
Expand Down Expand Up @@ -916,6 +965,83 @@ def _get_y_range(yguards, yind, pe_yind, nype, yproc_upper_target, mysub, myg, i
return ystart, ystop, ygstart, ygstop, inrange


def _get_z_range(zguards, zind, pe_zind, nzpe, mzsub, mzg, inrange):
"""
Get local ranges of z-indices

Parameters
----------
zguards : bool
Include z-boundaries?
zind : slice
Global slice to apply to z-dimension
pe_zind : int
z-indez of the processor
nzpe : int
Number of processors in the z-direction
mzsub : int
Number of grid cells (excluding guard cells) in the z-direction on a single
procssor
mzg : int
Number of guard cells in the z-direction
inrange : bool
Does the processor have data to read?

Returns
-------
zstart : int
Local z-index to start reading
zstop : int
Local z-index to stop reading
zgstart : int
Global z-index to start putting data
zgstop : int
Global z-index to stop putting data
inrange : bool
Updated version of inrange - changed to False if this processor has no data to
read
"""
# Local ranges
if zguards:
zstart = zind.start - pe_zind * mzsub
zstop = zind.stop - pe_zind * mzsub

# Check lower z boundary
if pe_zind == 0:
# Keeping inner boundary
zstart, inrange = _check_local_range_lower(zstart, zstop, 0, inrange)
else:
zstart, inrange = _check_local_range_lower(zstart, zstop, mzg, inrange)

# Upper z boundary
if pe_zind == (nzpe - 1):
# Keeping outer boundary
zstop, inrange = _check_local_range_upper(
zstart, zstop, mzsub + 2 * mzg, inrange
)
else:
zstop, inrange = _check_local_range_upper(
zstart, zstop, mzsub + mzg, inrange
)

else:
zstart = zind.start - pe_zind * mzsub + mzg
zstop = zind.stop - pe_zind * mzsub + mzg

zstart, inrange = _check_local_range_lower(zstart, zstop, mzg, inrange)
zstop, inrange = _check_local_range_upper(zstart, zstop, mzsub + mzg, inrange)

# Global ranges
if zguards:
zgstart = zstart + pe_zind * mzsub - zind.start
zgstop = zstop + pe_zind * mzsub - zind.start
else:
zgstart = zstart + pe_zind * mzsub - mzg - zind.start
zgstop = zstop + pe_zind * mzsub - mzg - zind.start

return zstart, zstop, zgstart, zgstop, inrange


def _check_fieldperp_attributes(
varname,
yindex_global,
Expand Down Expand Up @@ -950,7 +1076,17 @@ def _check_fieldperp_attributes(


def _get_grid_info(
f, *, xguards, yguards, tind, xind, yind, zind, nfiles, all_vars_info=False
f,
*,
xguards,
yguards,
zguards: bool,
tind,
xind,
yind,
zind,
nfiles,
all_vars_info=False,
):
"""Get the grid info from an open DataFile

Expand Down Expand Up @@ -993,8 +1129,10 @@ def load_and_check(varname):

mxg = int(load_and_check("MXG"))
myg = int(load_and_check("MYG"))
mzg = int(f.read("MZG") or 0)
mxsub = int(load_and_check("MXSUB"))
mysub = int(load_and_check("MYSUB"))
mzsub = int(f.read("MZSUB") or mz)
try:
nxpe = int(f["NXPE"])
except KeyError:
Expand All @@ -1006,6 +1144,9 @@ def load_and_check(varname):
nype = nfiles
print(f"NYPE not found, setting to {nype}")

# Don't warn, most files won't have this
nzpe = int(f.get("NZPE", 1))

if "t_array" in f.keys():
nt = len(f.read("t_array"))
else:
Expand All @@ -1031,7 +1172,12 @@ def load_and_check(varname):
else:
ny = mysub * nype

nz = mz - 1 if version < 3.5 else mz
if zguards:
nz = mzsub * nzpe + 2 * mzg
elif version < 3.5:
nz = mz - 1
else:
nz = mzsub * nzpe

tind = _convert_to_nice_slice(tind, nt, "tind")
xind = _convert_to_nice_slice(xind, nx, "xind")
Expand All @@ -1040,7 +1186,7 @@ def load_and_check(varname):

xsize = xind.stop - xind.start
ysize = yind.stop - yind.start
zsize = int(np.ceil(float(zind.stop - zind.start) / zind.step))
zsize = zind.stop - zind.start
tsize = int(np.ceil(float(tind.stop - tind.start) / tind.step))

# Map between dimension names and output size
Expand All @@ -1053,13 +1199,16 @@ def load_and_check(varname):
"mxsub": mxsub,
"myg": myg,
"mysub": mysub,
"mzg": mzg,
"mzsub": mzsub,
"nt": nt,
"npes": nxpe * nype,
"npes": nxpe * nype * nzpe,
"nx": nx,
"nxpe": nxpe,
"ny": ny,
"nype": nype,
"nz": nz,
"nzpe": nzpe,
"sizes": sizes,
"varNames": varNames,
"yproc_upper_target": yproc_upper_target,
Expand Down
Loading
Loading