Skip to content
Open

Stop #90

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ class NodesInfo(NamedTuple):


class FragmentsGlobals(NamedTuple):
stop: jnp.ndarray # [n_graph] bool array (only for training)
target_positions: jnp.ndarray # [n_graph, 3] float array (only for training)
target_species: jnp.ndarray # [n_graph] int array (only for training)


class FragmentsNodes(NamedTuple):
positions: jnp.ndarray # [n_node, 3] float array
species: jnp.ndarray # [n_node] int array
target_species_probs: jnp.ndarray # [n_node, n_species] float array (only for training)
finished: jnp.ndarray # [n_node] bool array
target_species_probs: jnp.ndarray # [n_node, n_species + 1] float array (only for training)


class Fragments(jraph.GraphsTuple):
Expand Down
144 changes: 93 additions & 51 deletions fragments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Iterator
from typing import Iterator, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -43,15 +43,16 @@ def generate_fragments(
) # [n_edge]

try:
rng, visited_nodes, frag = _make_first_fragment(
rng, visited_nodes, finished, frag = _make_first_fragment(
rng, graph, dist, n_species, nn_tolerance, max_radius, mode
)
yield frag

for _ in range(n - 2):
rng, visited_nodes, frag = _make_middle_fragment(
while len(visited_nodes) < n:
rng, visited_nodes, finished, frag = _make_middle_fragment(
rng,
visited_nodes,
finished,
graph,
dist,
n_species,
Expand All @@ -63,9 +64,9 @@ def generate_fragments(
except ValueError:
pass
else:
assert len(visited_nodes) == n

yield _make_last_fragment(graph, n_species)
while jnp.sum(finished) < n:
rng, finished, frag = _make_last_fragments(rng, finished, graph, n_species)
yield frag


def _make_first_fragment(rng, graph, dist, n_species, nn_tolerance, max_radius, mode):
Expand All @@ -87,115 +88,151 @@ def _make_first_fragment(rng, graph, dist, n_species, nn_tolerance, max_radius,
if len(targets) == 0:
raise ValueError("No targets found.")

num_nodes = graph.nodes.positions.shape[0]
species_probability = (
jnp.zeros((graph.nodes.positions.shape[0], n_species))
.at[first_node]
jnp.zeros((num_nodes, n_species + 1))
.at[first_node, :n_species]
.set(_normalized_bitcount(graph.nodes.species[targets], n_species))
)

# pick a random target
rng, k = jax.random.split(rng)
target = jax.random.choice(k, targets)

finished = jnp.zeros((num_nodes,), dtype=bool)
sample = _into_fragment(
graph,
visited=jnp.array([first_node]),
focus_node=first_node,
target_species_probability=species_probability,
target_node=target,
stop=False,
finished=finished,
)

visited = jnp.array([first_node, target])
return rng, visited, sample
return rng, visited, finished, sample


def _make_middle_fragment(
rng, visited, graph, dist, n_species, nn_tolerance, max_radius, mode
rng, visited, finished, graph, dist, n_species, nn_tolerance, max_radius, mode
):
assert finished.dtype == bool

n_nodes = len(graph.nodes.positions)
senders, receivers = graph.senders, graph.receivers

mask = jnp.isin(senders, visited) & ~jnp.isin(receivers, visited)

mask = mask & (dist < max_radius) & ~finished[senders]

# use max_radius to compute the stop probability:
s = jnp.zeros((n_nodes,))
for i in visited:
# i not finished and has no possible targets
if not finished[i] and jnp.sum((senders == i) & mask) == 0:
s = s.at[i].set(1.0)

# restrict to nearest neighbours:
if mode == "nn":
min_dist = dist[mask].min()
mask = mask & (dist < min_dist + nn_tolerance)
del min_dist
if mode == "radius":
mask = mask & (dist < max_radius)

n = jnp.zeros((n_nodes, n_species))
for focus_node in range(n_nodes):
targets = receivers[(senders == focus_node) & mask]
n = n.at[focus_node].set(
jnp.bincount(graph.nodes.species[targets], length=n_species)
)
for i in visited:
targets = receivers[(senders == i) & mask]
n = n.at[i].set(jnp.bincount(graph.nodes.species[targets], length=n_species))

if jnp.sum(n) == 0:
raise ValueError("No targets found.")

target_species_probability = n / jnp.sum(n)
# target_species_probability
# last entry is the stop probability
ts_pr = jnp.zeros((n_nodes, n_species + 1))
ts_pr = ts_pr.at[:, :n_species].set(n)
ts_pr = ts_pr.at[:, -1].set(s)
ts_pr = ts_pr / jnp.sum(ts_pr)

# pick a random focus node
# pick a random target specie (or stop)
rng, k = jax.random.split(rng)
focus_probability = _normalized_bitcount(senders[mask], n_nodes)
focus_node = jax.random.choice(k, n_nodes, p=focus_probability)
focus_node, target_specie = _sample_index(k, ts_pr)

# pick a random target
if target_specie == n_species:
# stop atom `focus_node`
new_finished = finished.at[focus_node].set(True)
sample = _into_fragment(graph, visited, focus_node, ts_pr, focus_node, finished)

return rng, visited, new_finished, sample

potential_targets = receivers[
(senders == focus_node)
& mask
& (graph.nodes.species[receivers] == target_specie)
]
assert len(potential_targets) > 0
rng, k = jax.random.split(rng)
targets = receivers[(senders == focus_node) & mask]
target_node = jax.random.choice(k, targets)
target_node = jax.random.choice(k, potential_targets)

new_visited = jnp.concatenate([visited, jnp.array([target_node])])

sample = _into_fragment(
graph,
visited,
focus_node,
target_species_probability,
target_node,
stop=False,
)
sample = _into_fragment(graph, visited, focus_node, ts_pr, target_node, finished)

return rng, new_visited, sample
return rng, new_visited, finished, sample


def _make_last_fragment(graph, n_species):
n_nodes = len(graph.nodes.positions)
return _into_fragment(
def _make_last_fragments(rng, finished, graph, n_species):
num_nodes = len(graph.nodes.positions)

ts_pr = jnp.zeros((num_nodes, n_species + 1))
ts_pr = ts_pr.at[~finished, -1].set(1.0)
ts_pr = ts_pr / jnp.sum(ts_pr)

rng, k = jax.random.split(rng)
focus_node, target_specie = _sample_index(k, ts_pr)
assert target_specie == n_species

sample = _into_fragment(
graph,
visited=jnp.arange(len(graph.nodes.positions)),
focus_node=0,
target_species_probability=jnp.zeros((n_nodes, n_species)),
target_node=0,
stop=True,
visited=jnp.arange(num_nodes),
focus_node=focus_node,
target_species_probability=ts_pr,
target_node=focus_node,
finished=finished,
)

finished = finished.at[focus_node].set(True)
return rng, finished, sample


def _into_fragment(
graph,
visited,
focus_node,
target_species_probability,
target_node,
stop,
finished,
):
pos = graph.nodes.positions
nodes = datatypes.FragmentsNodes(
positions=pos,
species=graph.nodes.species,
target_species_probs=target_species_probability,
finished=finished,
)
globals = datatypes.FragmentsGlobals(
stop=jnp.array([stop], dtype=bool), # [1]
target_species=graph.nodes.species[target_node][None], # [1]
target_positions=(pos[target_node] - pos[focus_node])[None], # [1, 3]
)
if target_node == focus_node:
# no target, focus node is stoped
globals = datatypes.FragmentsGlobals(
target_species=jnp.array([-1]), # [1]
target_positions=jnp.zeros((1, 3)), # [1, 3]
)
else:
globals = datatypes.FragmentsGlobals(
target_species=graph.nodes.species[target_node][None], # [1]
target_positions=(pos[target_node] - pos[focus_node])[None], # [1, 3]
)
graph = graph._replace(nodes=nodes, globals=globals)

if stop:
assert len(visited) == len(pos)
if len(visited) == len(pos):
return graph
else:
# put focus node at the beginning
Expand All @@ -216,6 +253,11 @@ def _normalized_bitcount(xs, n: int):
return jnp.bincount(xs, length=n) / len(xs)


def _sample_index(rng, p: jnp.ndarray) -> Tuple[int, ...]:
i = jax.random.choice(rng, jnp.arange(p.size), p=p.ravel())
return jnp.unravel_index(i, p.shape)


def subgraph(graph: jraph.GraphsTuple, nodes: jnp.ndarray) -> jraph.GraphsTuple:
"""Extract a subgraph from a graph.

Expand Down
6 changes: 6 additions & 0 deletions fragments/dispatch_fragmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def main(
root_dir: str,
mode: str,
):
if os.path.exists(root_dir):
print("Root directory already exists.")
else:
os.makedirs(root_dir)
print(f"Created root directory {root_dir}")

qm9_data = qm9.load_qm9("qm9_data")
starts = list(range(0, len(qm9_data), chunk))
del qm9_data
Expand Down
9 changes: 3 additions & 6 deletions fragments/fragmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def main(
"positions": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
"species": tf.TensorSpec(shape=(None,), dtype=tf.int32),
"target_species_probs": tf.TensorSpec(
shape=(None, len(atomic_numbers)), dtype=tf.float32
shape=(None, len(atomic_numbers) + 1), dtype=tf.float32
),
"finished": tf.TensorSpec(shape=(None,), dtype=tf.bool),
# edges
"senders": tf.TensorSpec(shape=(None,), dtype=tf.int32),
"receivers": tf.TensorSpec(shape=(None,), dtype=tf.int32),
# globals
"stop": tf.TensorSpec(shape=(1,), dtype=tf.bool),
"target_positions": tf.TensorSpec(shape=(1, 3), dtype=tf.float32),
"target_species": tf.TensorSpec(shape=(1,), dtype=tf.int32),
# n_node and n_edge
Expand All @@ -70,9 +70,6 @@ def generator():
"Target position is too far away from the rest of the molecule."
)
skip = True
if len(frags) == 0 or not frags[-1].globals.stop:
print("The last fragment is not a stop fragment.")
skip = True

if skip:
continue
Expand All @@ -84,9 +81,9 @@ def generator():
"target_species_probs": frag.nodes.target_species_probs.astype(
np.float32
),
"finished": frag.nodes.finished.astype(bool),
"senders": frag.senders.astype(np.int32),
"receivers": frag.receivers.astype(np.int32),
"stop": frag.globals.stop.astype(np.bool_),
"target_positions": frag.globals.target_positions.astype(
np.float32
),
Expand Down
8 changes: 5 additions & 3 deletions input_pipeline_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ def filter_file(filename: str, start: int, end: int) -> bool:
"Could not find the correct number of molecules in the first chunk."
)

dataset_split = dataset_split.skip(num_steps_to_skip).take(num_steps_to_take)
dataset_split = dataset_split.skip(num_steps_to_skip).take(
num_steps_to_take
)

# This is usually the case.
else:
Expand Down Expand Up @@ -300,26 +302,26 @@ def _convert_to_graphstuple(graph: Dict[str, tf.Tensor]) -> jraph.GraphsTuple:
positions = graph["positions"]
species = graph["species"]
target_species_probs = graph["target_species_probs"]
finished = graph["finished"]
receivers = graph["receivers"]
senders = graph["senders"]
n_node = graph["n_node"]
n_edge = graph["n_edge"]
edges = tf.ones((tf.shape(senders)[0], 1))
stop = graph["stop"]
target_positions = graph["target_positions"]
target_species = graph["target_species"]

return jraph.GraphsTuple(
nodes=datatypes.FragmentsNodes(
positions=positions,
species=species,
finished=finished,
target_species_probs=target_species_probs,
),
edges=edges,
receivers=receivers,
senders=senders,
globals=datatypes.FragmentsGlobals(
stop=stop,
target_positions=target_positions,
target_species=target_species,
),
Expand Down