Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

## New features

- Add convenience method `add_data` and `remove_data` to `GraphWidget`.

## Bug fixes

- Fixed a bug with the theme detection inn VSCode.
Expand Down
695 changes: 327 additions & 368 deletions examples/getting-started.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ py-sync:
cd python-wrapper && uv sync --group dev --group docs --group notebook --extra pandas --extra neo4j --extra gds --extra snowflake

py-style:
just py-sync
./scripts/makestyle.sh && ./scripts/checkstyle.sh

py-test:
Expand Down
2 changes: 1 addition & 1 deletion python-wrapper/src/neo4j_viz/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .node_size import RealNumber
from .options import CaptionAlignment

NodeIdType = Union[str, int]
NodeIdType = str | int


def create_aliases(field_name: str) -> AliasChoices:
Expand Down
4 changes: 3 additions & 1 deletion python-wrapper/src/neo4j_viz/relationship.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from .options import CaptionAlignment

RelationshipIdType = str | int


def create_aliases(field_name: str) -> AliasChoices:
valid_names = [field_name]
Expand Down Expand Up @@ -43,7 +45,7 @@ class Relationship(
"""

#: Unique identifier for the relationship
id: Union[str, int] = Field(
id: RelationshipIdType = Field(
default_factory=lambda: uuid4().hex, description="Unique identifier for the relationship"
)
#: Node ID where the relationship points from
Expand Down
9 changes: 6 additions & 3 deletions python-wrapper/src/neo4j_viz/visualization_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _build_render_options(
self,
layout: Layout | None,
layout_options: dict[str, Any] | LayoutOptions | None,
renderer: Renderer,
renderer: Renderer | str,
pan_position: tuple[float, float] | None,
initial_zoom: float | None,
min_zoom: float,
Expand All @@ -105,6 +105,9 @@ def _build_render_options(
"overriding `max_allowed_nodes`, but rendering could then take a long time"
)

if isinstance(renderer, str):
renderer = Renderer(renderer)

Renderer.check(renderer, num_nodes)

if not layout:
Expand Down Expand Up @@ -133,7 +136,7 @@ def render(
self,
layout: Layout | None = None,
layout_options: dict[str, Any] | LayoutOptions | None = None,
renderer: Renderer = Renderer.CANVAS,
renderer: Renderer | str = Renderer.CANVAS,
width: str = "100%",
height: str = "600px",
pan_position: tuple[float, float] | None = None,
Expand Down Expand Up @@ -207,7 +210,7 @@ def render_widget(
self,
layout: Layout | None = None,
layout_options: dict[str, Any] | LayoutOptions | None = None,
renderer: Renderer = Renderer.CANVAS,
renderer: Renderer | str = Renderer.CANVAS,
width: str = "100%",
height: str = "600px",
pan_position: tuple[float, float] | None = None,
Expand Down
73 changes: 71 additions & 2 deletions python-wrapper/src/neo4j_viz/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import anywidget
import traitlets

from .node import Node
from .node import Node, NodeIdType
from .options import RenderOptions
from .relationship import Relationship
from .relationship import Relationship, RelationshipIdType


def _serialize_entity(entity: Union[Node, Relationship]) -> dict[str, Any]:
Expand Down Expand Up @@ -79,3 +79,72 @@ def from_graph_data(
options=options.to_js_options() if options else {},
theme=theme,
)

def add_data(
self, nodes: Node | list[Node] | None = None, relationships: Relationship | list[Relationship] | None = None
) -> None:
"""
Add nodes or relationships to the graph widget.

Parameters
-----------
nodes:
Nodes to add to the graph widget.
relationships:
Relationships to add to the graph widget.
"""
if isinstance(nodes, Node):
nodes = [nodes]
if isinstance(relationships, Relationship):
relationships = [relationships]

if nodes:
self.nodes = self.nodes + [_serialize_entity(n) for n in nodes]
if relationships:
self.relationships = self.relationships + [_serialize_entity(r) for r in relationships]

def remove_data(
self,
nodes: Node | list[Node | NodeIdType] | NodeIdType | None = None,
relationships: Relationship | list[Relationship | RelationshipIdType] | RelationshipIdType | None = None,
) -> None:
"""
Remove nodes or relationships from the graph widget.

Parameters
-----------
nodes:
Nodes to remove from the graph widget.
relationships:
Relationships to remove from the graph widget.
"""
if isinstance(nodes, Node):
node_ids_to_remove = {nodes.id}
elif isinstance(nodes, NodeIdType):
node_ids_to_remove = {nodes}
elif nodes is None:
node_ids_to_remove = set()
else:
node_ids_to_remove = {n.id if isinstance(n, Node) else n for n in nodes}

if isinstance(relationships, Relationship):
rel_ids_to_remove = {relationships.id}
elif isinstance(relationships, RelationshipIdType):
rel_ids_to_remove = {relationships}
elif relationships is None:
rel_ids_to_remove = set()
else:
rel_ids_to_remove = {r.id if isinstance(r, Relationship) else r for r in relationships}

if node_ids_to_remove:
self.nodes = [n for n in self.nodes if n["id"] not in node_ids_to_remove]

def keep_rel(r: dict[str, Any]) -> bool:
return (
r["id"] not in rel_ids_to_remove
and r["from"] not in node_ids_to_remove
and r["to"] not in node_ids_to_remove
)

if rel_ids_to_remove:
self.relationships = [r for r in self.relationships if keep_rel(r)]
27 changes: 27 additions & 0 deletions python-wrapper/tests/test_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,33 @@ def test_replace_all_data(self) -> None:
assert len(widget.relationships) == 2
assert widget.nodes[0]["id"] == "x1"

def test_add_data(self) -> None:
"""Test adding new data to existing graph."""
nodes = [Node(id="n1"), Node(id="n2")]
rels = [Relationship(source="n1", target="n2")]
widget = GraphWidget.from_graph_data(nodes, rels)

widget.add_data(Node(id="x1"), Relationship(source="x1", target="x2"))

assert len(widget.nodes) == 3
assert len(widget.relationships) == 2

def test_remove_data(self) -> None:
"""Test removing data from the graph."""
node_1 = Node(id="n1")
nodes = [node_1, Node(id="n2"), Node(id="n3")]
rels = [
Relationship(source="n1", target="n2"),
Relationship(id=42, source="n2", target="n1"),
Relationship(source="n2", target="n1"), # detach delete
Relationship(id=43, source="n3", target="n3"),
]
widget = GraphWidget.from_graph_data(nodes, rels)

widget.remove_data(nodes=[node_1, "n2"], relationships=[rels[0], "42"])
assert {n["id"] for n in widget.nodes} == {"n3"}
assert {r["id"] for r in widget.relationships} == {"43"}


render_widget_cases = {
"default": {},
Expand Down
Loading