From b2f2b55f2be0fcf483e46129054ba93a0fe44f71 Mon Sep 17 00:00:00 2001 From: Neel Shah Date: Thu, 4 Dec 2025 14:22:28 -0800 Subject: [PATCH 1/4] Add enable/disable magic for showing graph tab --- src/graph_notebook/magics/graph_magic.py | 331 ++++++++++++----------- 1 file changed, 173 insertions(+), 158 deletions(-) diff --git a/src/graph_notebook/magics/graph_magic.py b/src/graph_notebook/magics/graph_magic.py index a3400e01..83b9068c 100644 --- a/src/graph_notebook/magics/graph_magic.py +++ b/src/graph_notebook/magics/graph_magic.py @@ -423,6 +423,7 @@ def __init__(self, shell): self.max_results = DEFAULT_MAX_RESULTS self.graph_notebook_vis_options = OPTIONS_DEFAULT_DIRECTED self._generate_client_from_config(self.graph_notebook_config) + self.show_graph_tab = True root_logger.setLevel(logging.CRITICAL) logger.setLevel(logging.ERROR) @@ -881,66 +882,67 @@ def sparql(self, line='', cell='', local_ns: dict = None): children.append(raw_output) titles.append('Raw') else: - if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']: - # TODO: Serialize other result types to SPARQL JSON so we can create table and visualization - logger.debug('creating sparql network...') - - titles.append('Table') - - sn = SPARQLNetwork(group_by_property=args.group_by, - display_property=args.display_property, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.edge_label_max_length, - ignore_groups=args.ignore_groups, - expand_all=args.expand_all, - group_by_raw=args.group_by_raw) - - sn.extract_prefix_declarations_from_query(cell) - try: - sn.add_results(results) - except ValueError as value_error: - logger.debug(value_error) - - logger.debug(f'number of nodes is {len(sn.graph.nodes)}') - if len(sn.graph.nodes) > 0: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - f = Force(network=sn, options=self.graph_notebook_vis_options) - titles.append('Graph') - children.append(f) - logger.debug('added sparql network to tabs') - - rows_and_columns = sparql_get_rows_and_columns(results) - if rows_and_columns is not None: - results_df = pd.DataFrame(rows_and_columns['rows']).convert_dtypes() - results_df = results_df.astype(str) - results_df = results_df.map(lambda x: encode_html_chars(x)) - results_df.insert(0, "#", range(1, len(results_df) + 1)) - for col_index, col_name in enumerate(rows_and_columns['columns']): - try: - results_df.rename({results_df.columns[col_index + 1]: col_name}, - axis='columns', - inplace=True) - except IndexError: - results_df.insert(col_index + 1, col_name, []) - - # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result - # pattern of showing a tsv with each line being a result binding in addition to new ones. - if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE': - lines = [] - for b in results['results']['bindings']: - lines.append( - f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}') - raw_output = widgets.Output(layout=sparql_layout) - with raw_output: - html = sparql_construct_template.render(lines=lines) - display(HTML(html)) - children.append(raw_output) - titles.append('Raw') + if self.show_graph_tab: + if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']: + # TODO: Serialize other result types to SPARQL JSON so we can create table and visualization + logger.debug('creating sparql network...') + + titles.append('Table') + + sn = SPARQLNetwork(group_by_property=args.group_by, + display_property=args.display_property, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.edge_label_max_length, + ignore_groups=args.ignore_groups, + expand_all=args.expand_all, + group_by_raw=args.group_by_raw) + + sn.extract_prefix_declarations_from_query(cell) + try: + sn.add_results(results) + except ValueError as value_error: + logger.debug(value_error) + + logger.debug(f'number of nodes is {len(sn.graph.nodes)}') + if len(sn.graph.nodes) > 0: + self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ + = args.stop_physics + self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration + f = Force(network=sn, options=self.graph_notebook_vis_options) + titles.append('Graph') + children.append(f) + logger.debug('added sparql network to tabs') + + rows_and_columns = sparql_get_rows_and_columns(results) + if rows_and_columns is not None: + results_df = pd.DataFrame(rows_and_columns['rows']).convert_dtypes() + results_df = results_df.astype(str) + results_df = results_df.map(lambda x: encode_html_chars(x)) + results_df.insert(0, "#", range(1, len(results_df) + 1)) + for col_index, col_name in enumerate(rows_and_columns['columns']): + try: + results_df.rename({results_df.columns[col_index + 1]: col_name}, + axis='columns', + inplace=True) + except IndexError: + results_df.insert(col_index + 1, col_name, []) + + # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result + # pattern of showing a tsv with each line being a result binding in addition to new ones. + if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE': + lines = [] + for b in results['results']['bindings']: + lines.append( + f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}') + raw_output = widgets.Output(layout=sparql_layout) + with raw_output: + html = sparql_construct_template.render(lines=lines) + display(HTML(html)) + children.append(raw_output) + titles.append('Raw') json_output = widgets.Output(layout=sparql_layout) with json_output: @@ -1343,57 +1345,58 @@ def gremlin(self, line, cell, local_ns: dict = None): titles.append('Console') gremlin_network = None - try: - logger.debug(f'groupby: {args.group_by}') - logger.debug(f'display_property: {args.display_property}') - logger.debug(f'edge_display_property: {args.edge_display_property}') - logger.debug(f'label_max_length: {args.label_max_length}') - logger.debug(f'ignore_groups: {args.ignore_groups}') - gn = GremlinNetwork(group_by_property=args.group_by, - display_property=args.display_property, - group_by_raw=args.group_by_raw, - group_by_depth=args.group_by_depth, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.edge_label_max_length, - ignore_groups=args.ignore_groups, - using_http=using_http) - - if using_http and 'path()' in cell and query_res and isinstance(query_res, list): - first_path = query_res[0] - if isinstance(first_path, dict) and first_path.keys() == {'labels', 'objects'}: - query_res_to_path_type = [] - for path in query_res: - new_path_list = path['objects'] - new_path = Path(labels=[], objects=new_path_list) - query_res_to_path_type.append(new_path) - query_res = query_res_to_path_type - - if args.path_pattern == '': - gn.add_results(query_res, is_http=using_http) - else: - pattern = parse_pattern_list_str(args.path_pattern) - gn.add_results_with_pattern(query_res, pattern) - gremlin_network = gn - logger.debug(f'number of nodes is {len(gn.graph.nodes)}') - except ValueError as value_error: - logger.debug( - f'Unable to create graph network from result due to error: {value_error}. ' - f'Skipping from result set.') - if gremlin_network and len(gremlin_network.graph.nodes) > 0: + if self.show_graph_tab: try: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - f = Force(network=gremlin_network, options=self.graph_notebook_vis_options) - titles.append('Graph') - children.append(f) - logger.debug('added gremlin network to tabs') - except Exception as force_error: + logger.debug(f'groupby: {args.group_by}') + logger.debug(f'display_property: {args.display_property}') + logger.debug(f'edge_display_property: {args.edge_display_property}') + logger.debug(f'label_max_length: {args.label_max_length}') + logger.debug(f'ignore_groups: {args.ignore_groups}') + gn = GremlinNetwork(group_by_property=args.group_by, + display_property=args.display_property, + group_by_raw=args.group_by_raw, + group_by_depth=args.group_by_depth, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.edge_label_max_length, + ignore_groups=args.ignore_groups, + using_http=using_http) + + if using_http and 'path()' in cell and query_res and isinstance(query_res, list): + first_path = query_res[0] + if isinstance(first_path, dict) and first_path.keys() == {'labels', 'objects'}: + query_res_to_path_type = [] + for path in query_res: + new_path_list = path['objects'] + new_path = Path(labels=[], objects=new_path_list) + query_res_to_path_type.append(new_path) + query_res = query_res_to_path_type + + if args.path_pattern == '': + gn.add_results(query_res, is_http=using_http) + else: + pattern = parse_pattern_list_str(args.path_pattern) + gn.add_results_with_pattern(query_res, pattern) + gremlin_network = gn + logger.debug(f'number of nodes is {len(gn.graph.nodes)}') + except ValueError as value_error: logger.debug( - f'Unable to render visualization from graph network due to error: {force_error}. Skipping.') + f'Unable to create graph network from result due to error: {value_error}. ' + f'Skipping from result set.') + if gremlin_network and len(gremlin_network.graph.nodes) > 0: + try: + self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ + = args.stop_physics + self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration + f = Force(network=gremlin_network, options=self.graph_notebook_vis_options) + titles.append('Graph') + children.append(f) + logger.debug('added gremlin network to tabs') + except Exception as force_error: + logger.debug( + f'Unable to render visualization from graph network due to error: {force_error}. Skipping.') # Check if we can access the CDNs required by itables library. # If not, then render our own HTML template. @@ -3474,6 +3477,14 @@ def enable_debug(self, line): def disable_debug(self, line): logger.setLevel(logging.ERROR) root_logger.setLevel(logging.CRITICAL) + + @line_magic + def enable_graph_tab(self, line): + self.show_graph_tab = True + + @line_magic + def disable_graph_tab(self, line): + self.show_graph_tab = False @line_magic @needs_local_scope @@ -3726,28 +3737,30 @@ def handle_opencypher_query(self, line, cell, local_ns): results_df, has_results = oc_results_df(res, res_format) if has_results: titles.append('Console') - try: - gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property, - group_by_raw=args.group_by_raw, - group_by_depth=args.group_by_depth, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.rel_label_max_length, - ignore_groups=args.ignore_groups) - gn.add_results(res) - logger.debug(f'number of nodes is {len(gn.graph.nodes)}') - if len(gn.graph.nodes) > 0: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - force_graph_output = Force(network=gn, options=self.graph_notebook_vis_options) - titles.append('Graph') - children.append(force_graph_output) - except (TypeError, ValueError) as network_creation_error: - logger.debug(f'Unable to create network from result. Skipping from result set: {res}') - logger.debug(f'Error: {network_creation_error}') + + if self.show_graph_tab: + try: + gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property, + group_by_raw=args.group_by_raw, + group_by_depth=args.group_by_depth, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.rel_label_max_length, + ignore_groups=args.ignore_groups) + gn.add_results(res) + logger.debug(f'number of nodes is {len(gn.graph.nodes)}') + if len(gn.graph.nodes) > 0: + self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ + = args.stop_physics + self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration + force_graph_output = Force(network=gn, options=self.graph_notebook_vis_options) + titles.append('Graph') + children.append(force_graph_output) + except (TypeError, ValueError) as network_creation_error: + logger.debug(f'Unable to create network from result. Skipping from result set: {res}') + logger.debug(f'Error: {network_creation_error}') elif args.mode == 'bolt': res_format = 'bolt' @@ -3764,33 +3777,35 @@ def handle_opencypher_query(self, line, cell, local_ns): results_df, has_results = oc_results_df(res, res_format) if has_results: titles.append('Console') - # Create graph visualization for bolt response - try: - # Wrap bolt response in expected format - # Required because the graph visualizer need the data to be present in a certain format - transformed_res = {"results": res} if isinstance(res, list) else {"results": []} - - gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property, - group_by_raw=args.group_by_raw, - group_by_depth=args.group_by_depth, - edge_display_property=args.edge_display_property, - tooltip_property=args.tooltip_property, - edge_tooltip_property=args.edge_tooltip_property, - label_max_length=args.label_max_length, - edge_label_max_length=args.rel_label_max_length, - ignore_groups=args.ignore_groups) - gn.add_results(transformed_res) - logger.debug(f'number of nodes is {len(gn.graph.nodes)}') - if len(gn.graph.nodes) > 0: - self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ - = args.stop_physics - self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration - force_graph_output = Force(network=gn, options=self.graph_notebook_vis_options) - titles.append('Graph') - children.append(force_graph_output) - except (TypeError, ValueError) as network_creation_error: - logger.debug(f'Unable to create network from bolt result. Skipping from result set: {res}') - logger.debug(f'Error: {network_creation_error}') + + if self.show_graph_tab: + # Create graph visualization for bolt response + try: + # Wrap bolt response in expected format + # Required because the graph visualizer need the data to be present in a certain format + transformed_res = {"results": res} if isinstance(res, list) else {"results": []} + + gn = OCNetwork(group_by_property=args.group_by, display_property=args.display_property, + group_by_raw=args.group_by_raw, + group_by_depth=args.group_by_depth, + edge_display_property=args.edge_display_property, + tooltip_property=args.tooltip_property, + edge_tooltip_property=args.edge_tooltip_property, + label_max_length=args.label_max_length, + edge_label_max_length=args.rel_label_max_length, + ignore_groups=args.ignore_groups) + gn.add_results(transformed_res) + logger.debug(f'number of nodes is {len(gn.graph.nodes)}') + if len(gn.graph.nodes) > 0: + self.graph_notebook_vis_options['physics']['disablePhysicsAfterInitialSimulation'] \ + = args.stop_physics + self.graph_notebook_vis_options['physics']['simulationDuration'] = args.simulation_duration + force_graph_output = Force(network=gn, options=self.graph_notebook_vis_options) + titles.append('Graph') + children.append(force_graph_output) + except (TypeError, ValueError) as network_creation_error: + logger.debug(f'Unable to create network from bolt result. Skipping from result set: {res}') + logger.debug(f'Error: {network_creation_error}') if not args.silent: if args.mode != 'explain': From 383fe1e26c0bd064fb8c361e873eb43be07e9031 Mon Sep 17 00:00:00 2001 From: Neel Shah Date: Thu, 4 Dec 2025 14:22:28 -0800 Subject: [PATCH 2/4] Add enable/disable magic for showing graph tab --- src/graph_notebook/network/Network.py | 231 +++++++++++++------------- 1 file changed, 117 insertions(+), 114 deletions(-) diff --git a/src/graph_notebook/network/Network.py b/src/graph_notebook/network/Network.py index bbbead0b..48675c44 100644 --- a/src/graph_notebook/network/Network.py +++ b/src/graph_notebook/network/Network.py @@ -1,114 +1,117 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json - -from networkx import MultiDiGraph -from networkx.readwrite import json_graph - -ERROR_EDGE_NOT_FOUND = ValueError("Edge was not found on network graph") -ERROR_INVALID_DATA = ValueError("Data must be a dict") - - -class Network: - """ - Network wraps a Networkx MultiDiGraph and provides some utilities - to add nodes and edges to the graph. For use each language meant to use it, - the Network class will be extended to ensure that we are adding the data needed to ensure that - we maintain the properties inside each node and edge appropriately. - """ - - def __init__(self, graph: MultiDiGraph = None): - if graph is None: - graph = MultiDiGraph() - self.graph = graph - - def add_node_property(self, node_id: str, key: str, value: str): - """ - updates the "properties" key on the given :param node_id. For instance, if key=foo, and value=bar, - then the given node would now be guaranteed to have the entry node['properties']['foo'] = bar - :param node_id: id of the node to update - :param key: the key to update under this nodes' properties dict - :param value: the value to set - """ - node = self.graph.nodes.get(node_id) - if node is None: - node = self.graph.add_node(node_id) - - if 'properties' not in node: - node['properties'] = {key: value} - else: - node['properties'][key] = value - - def add_node(self, node_id: str, data=None): - if data is None: - data = {} - self.graph.add_node(node_id, **data) - - def add_edge(self, from_id: str, to_id: str, edge_id: str, label: str, data: dict = None): - if data is None: - data = {} - - data['label'] = label - self.graph.add_edge(from_id, to_id, edge_id, **data) - - def add_node_data(self, node_id: str, data: dict): - """ - overrides the keys on a node with the data found in :param data - :param node_id: the id of the node to update - :param data: key-value dictionary to update node with - """ - - if type(data) is not dict: - raise ERROR_INVALID_DATA - - node = self.graph.nodes.get(node_id) - if node is None: - self.add_node(node_id, data) - return - - for key in data: - node[key] = data[key] - - def add_edge_data(self, from_id: str, to_id: str, edge_id, data: dict): - if not self.graph.has_edge(from_id, to_id, edge_id): - raise ERROR_EDGE_NOT_FOUND - - if type(data) is not dict: - raise ERROR_INVALID_DATA - - edge = self.graph.edges[from_id, to_id, edge_id] - for key in data: - edge[key] = data[key] - - def add_results(self, results): - """ - base method to be overridden by implementations to add results. - For SPARQL, these results are a dict with bindings, for Gremlin, they are paths - :param results: - :return: - """ - pass - - def to_json(self) -> dict: - try: - # NetworkX 2.6+ - graph_data = json_graph.node_link_data(self.graph, edges="links") - except TypeError: - # NetworkX < 2.6 - graph_data = json_graph.node_link_data(self.graph) - return {'graph': graph_data} - - -def network_to_json(network: Network) -> str: - return json.dumps(network.to_json()) - - -def network_from_json(raw) -> Network: - data = json.loads(raw) - network = Network() - if 'graph' in data: - network.graph = json_graph.node_link_graph(data['graph'], directed=True) - return network +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import json + +from networkx import MultiDiGraph +from networkx.readwrite import json_graph + +ERROR_EDGE_NOT_FOUND = ValueError("Edge was not found on network graph") +ERROR_INVALID_DATA = ValueError("Data must be a dict") + + +class Network: + """ + Network wraps a Networkx MultiDiGraph and provides some utilities + to add nodes and edges to the graph. For use each language meant to use it, + the Network class will be extended to ensure that we are adding the data needed to ensure that + we maintain the properties inside each node and edge appropriately. + """ + + def __init__(self, graph: MultiDiGraph = None): + if graph is None: + graph = MultiDiGraph() + self.graph = graph + + def add_node_property(self, node_id: str, key: str, value: str): + """ + updates the "properties" key on the given :param node_id. For instance, if key=foo, and value=bar, + then the given node would now be guaranteed to have the entry node['properties']['foo'] = bar + :param node_id: id of the node to update + :param key: the key to update under this nodes' properties dict + :param value: the value to set + """ + node = self.graph.nodes.get(node_id) + if node is None: + node = self.graph.add_node(node_id) + + if 'properties' not in node: + node['properties'] = {key: value} + else: + node['properties'][key] = value + + def add_node(self, node_id: str, data=None): + if data is None: + data = {} + self.graph.add_node(node_id, **data) + + def add_edge(self, from_id: str, to_id: str, edge_id: str, label: str, data: dict = None): + if data is None: + data = {} + + data['label'] = label + self.graph.add_edge(from_id, to_id, edge_id, **data) + + def add_node_data(self, node_id: str, data: dict): + """ + overrides the keys on a node with the data found in :param data + :param node_id: the id of the node to update + :param data: key-value dictionary to update node with + """ + + if type(data) is not dict: + raise ERROR_INVALID_DATA + + node = self.graph.nodes.get(node_id) + if node is None: + self.add_node(node_id, data) + return + + for key in data: + node[key] = data[key] + + def add_edge_data(self, from_id: str, to_id: str, edge_id, data: dict): + if not self.graph.has_edge(from_id, to_id, edge_id): + raise ERROR_EDGE_NOT_FOUND + + if type(data) is not dict: + raise ERROR_INVALID_DATA + + edge = self.graph.edges[from_id, to_id, edge_id] + for key in data: + edge[key] = data[key] + + def add_results(self, results): + """ + base method to be overridden by implementations to add results. + For SPARQL, these results are a dict with bindings, for Gremlin, they are paths + :param results: + :return: + """ + pass + + def to_json(self) -> dict: + try: + # NetworkX 2.6+ + graph_data = json_graph.node_link_data(self.graph, edges="links") + except TypeError: + # NetworkX < 2.6 + graph_data = json_graph.node_link_data(self.graph) + return {'graph': graph_data} + + +def network_to_json(network: Network) -> str: + return json.dumps(network.to_json()) + + +def network_from_json(raw) -> Network: + data = json.loads(raw) + network = Network() + if 'graph' in data: + try: + network.graph = json_graph.node_link_graph(data['graph'], directed=True, edges="links") + except: + network.graph = json_graph.node_link_graph(data['graph'], directed=True) + return network From 422877966fce6f04130be87f2395d850a0ef11cf Mon Sep 17 00:00:00 2001 From: Neel Shah Date: Thu, 4 Dec 2025 15:18:29 -0800 Subject: [PATCH 3/4] Revert "Add enable/disable magic for showing graph tab" This reverts commit 383fe1e26c0bd064fb8c361e873eb43be07e9031. --- src/graph_notebook/network/Network.py | 231 +++++++++++++------------- 1 file changed, 114 insertions(+), 117 deletions(-) diff --git a/src/graph_notebook/network/Network.py b/src/graph_notebook/network/Network.py index 48675c44..bbbead0b 100644 --- a/src/graph_notebook/network/Network.py +++ b/src/graph_notebook/network/Network.py @@ -1,117 +1,114 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -SPDX-License-Identifier: Apache-2.0 -""" - -import json - -from networkx import MultiDiGraph -from networkx.readwrite import json_graph - -ERROR_EDGE_NOT_FOUND = ValueError("Edge was not found on network graph") -ERROR_INVALID_DATA = ValueError("Data must be a dict") - - -class Network: - """ - Network wraps a Networkx MultiDiGraph and provides some utilities - to add nodes and edges to the graph. For use each language meant to use it, - the Network class will be extended to ensure that we are adding the data needed to ensure that - we maintain the properties inside each node and edge appropriately. - """ - - def __init__(self, graph: MultiDiGraph = None): - if graph is None: - graph = MultiDiGraph() - self.graph = graph - - def add_node_property(self, node_id: str, key: str, value: str): - """ - updates the "properties" key on the given :param node_id. For instance, if key=foo, and value=bar, - then the given node would now be guaranteed to have the entry node['properties']['foo'] = bar - :param node_id: id of the node to update - :param key: the key to update under this nodes' properties dict - :param value: the value to set - """ - node = self.graph.nodes.get(node_id) - if node is None: - node = self.graph.add_node(node_id) - - if 'properties' not in node: - node['properties'] = {key: value} - else: - node['properties'][key] = value - - def add_node(self, node_id: str, data=None): - if data is None: - data = {} - self.graph.add_node(node_id, **data) - - def add_edge(self, from_id: str, to_id: str, edge_id: str, label: str, data: dict = None): - if data is None: - data = {} - - data['label'] = label - self.graph.add_edge(from_id, to_id, edge_id, **data) - - def add_node_data(self, node_id: str, data: dict): - """ - overrides the keys on a node with the data found in :param data - :param node_id: the id of the node to update - :param data: key-value dictionary to update node with - """ - - if type(data) is not dict: - raise ERROR_INVALID_DATA - - node = self.graph.nodes.get(node_id) - if node is None: - self.add_node(node_id, data) - return - - for key in data: - node[key] = data[key] - - def add_edge_data(self, from_id: str, to_id: str, edge_id, data: dict): - if not self.graph.has_edge(from_id, to_id, edge_id): - raise ERROR_EDGE_NOT_FOUND - - if type(data) is not dict: - raise ERROR_INVALID_DATA - - edge = self.graph.edges[from_id, to_id, edge_id] - for key in data: - edge[key] = data[key] - - def add_results(self, results): - """ - base method to be overridden by implementations to add results. - For SPARQL, these results are a dict with bindings, for Gremlin, they are paths - :param results: - :return: - """ - pass - - def to_json(self) -> dict: - try: - # NetworkX 2.6+ - graph_data = json_graph.node_link_data(self.graph, edges="links") - except TypeError: - # NetworkX < 2.6 - graph_data = json_graph.node_link_data(self.graph) - return {'graph': graph_data} - - -def network_to_json(network: Network) -> str: - return json.dumps(network.to_json()) - - -def network_from_json(raw) -> Network: - data = json.loads(raw) - network = Network() - if 'graph' in data: - try: - network.graph = json_graph.node_link_graph(data['graph'], directed=True, edges="links") - except: - network.graph = json_graph.node_link_graph(data['graph'], directed=True) - return network +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +SPDX-License-Identifier: Apache-2.0 +""" + +import json + +from networkx import MultiDiGraph +from networkx.readwrite import json_graph + +ERROR_EDGE_NOT_FOUND = ValueError("Edge was not found on network graph") +ERROR_INVALID_DATA = ValueError("Data must be a dict") + + +class Network: + """ + Network wraps a Networkx MultiDiGraph and provides some utilities + to add nodes and edges to the graph. For use each language meant to use it, + the Network class will be extended to ensure that we are adding the data needed to ensure that + we maintain the properties inside each node and edge appropriately. + """ + + def __init__(self, graph: MultiDiGraph = None): + if graph is None: + graph = MultiDiGraph() + self.graph = graph + + def add_node_property(self, node_id: str, key: str, value: str): + """ + updates the "properties" key on the given :param node_id. For instance, if key=foo, and value=bar, + then the given node would now be guaranteed to have the entry node['properties']['foo'] = bar + :param node_id: id of the node to update + :param key: the key to update under this nodes' properties dict + :param value: the value to set + """ + node = self.graph.nodes.get(node_id) + if node is None: + node = self.graph.add_node(node_id) + + if 'properties' not in node: + node['properties'] = {key: value} + else: + node['properties'][key] = value + + def add_node(self, node_id: str, data=None): + if data is None: + data = {} + self.graph.add_node(node_id, **data) + + def add_edge(self, from_id: str, to_id: str, edge_id: str, label: str, data: dict = None): + if data is None: + data = {} + + data['label'] = label + self.graph.add_edge(from_id, to_id, edge_id, **data) + + def add_node_data(self, node_id: str, data: dict): + """ + overrides the keys on a node with the data found in :param data + :param node_id: the id of the node to update + :param data: key-value dictionary to update node with + """ + + if type(data) is not dict: + raise ERROR_INVALID_DATA + + node = self.graph.nodes.get(node_id) + if node is None: + self.add_node(node_id, data) + return + + for key in data: + node[key] = data[key] + + def add_edge_data(self, from_id: str, to_id: str, edge_id, data: dict): + if not self.graph.has_edge(from_id, to_id, edge_id): + raise ERROR_EDGE_NOT_FOUND + + if type(data) is not dict: + raise ERROR_INVALID_DATA + + edge = self.graph.edges[from_id, to_id, edge_id] + for key in data: + edge[key] = data[key] + + def add_results(self, results): + """ + base method to be overridden by implementations to add results. + For SPARQL, these results are a dict with bindings, for Gremlin, they are paths + :param results: + :return: + """ + pass + + def to_json(self) -> dict: + try: + # NetworkX 2.6+ + graph_data = json_graph.node_link_data(self.graph, edges="links") + except TypeError: + # NetworkX < 2.6 + graph_data = json_graph.node_link_data(self.graph) + return {'graph': graph_data} + + +def network_to_json(network: Network) -> str: + return json.dumps(network.to_json()) + + +def network_from_json(raw) -> Network: + data = json.loads(raw) + network = Network() + if 'graph' in data: + network.graph = json_graph.node_link_graph(data['graph'], directed=True) + return network From 1c1391e652331a42937d156617f8718e3af53d11 Mon Sep 17 00:00:00 2001 From: Neel Shah Date: Thu, 4 Dec 2025 15:20:02 -0800 Subject: [PATCH 4/4] Add fallback for NetworkX upgrade --- src/graph_notebook/network/Network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/graph_notebook/network/Network.py b/src/graph_notebook/network/Network.py index bbbead0b..4005b6da 100644 --- a/src/graph_notebook/network/Network.py +++ b/src/graph_notebook/network/Network.py @@ -110,5 +110,8 @@ def network_from_json(raw) -> Network: data = json.loads(raw) network = Network() if 'graph' in data: - network.graph = json_graph.node_link_graph(data['graph'], directed=True) + try: + network.graph = json_graph.node_link_graph(data['graph'], directed=True, edges="links") + except: + network.graph = json_graph.node_link_graph(data['graph'], directed=True) return network