@rjurney
Created January 27, 2026 18:54
GEXFWriter - make Gephi visualizations of GraphFrames
import typing
import xml.etree.ElementTree as ET
from urllib.parse import urlparse
import boto3
import pandas as pd
import pyspark.sql
from bs4 import BeautifulSoup
from moto import mock_s3
from mypy_boto3_s3.type_defs import CreateBucketConfigurationTypeDef
from typeguard import typechecked
class DataFrameFormatException(Exception):
"""DataFrameFormatException thrown when encountering unfamiliar DataFrames, neither pandas or pyspark.sql"""
pass
class InvalidAttributeTypeException(Exception):
"""InvalidAttributeTypeException raised when an unfamiliar class is set as a node/edge attribute"""
pass
class MissingNodePropertyException(Exception):
"""MissingNodePropertyException raised when a required node property is missing: id, label"""
pass
class MissingEdgePropertyException(Exception):
"""MissingEdgePropertyException raised when a requires edge property is missing: source, target"""
pass
class InvalidS3URI(Exception):
"""InvalidS3URI raised when an invalid S3 url is parsed by GEXFWriter"""
class MissingNodeException(Exception):
"""MissingNodeException raised when an edge is created without an associated node"""
class GEXFWriter:
"""GEXFWriter Given a set of nodes and edges from a GraphFrames motif, generate a GEXF file for Gephi
An example of valid GEXF XML is:
<?xml version="1.0" encoding="UTF-8"?>
<gexf xmlns="http://www.gexf.net/1.2draft" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.gexf.net/1.2draft http://www.gexf.net/1.2draft/gexf.xsd" version="1.2">
<meta lastmodifieddate="2009-03-20">
<creator>Gephi.org</creator>
<description>A Web network</description>
</meta>
<graph defaultedgetype="directed">
<attributes class="node">
<attribute id="0" title="url" type="string"/>
<attribute id="1" title="indegree" type="float"/>
<attribute id="2" title="frog" type="boolean">
<default>true</default>
</attribute>
</attributes>
<attributes class="edge">
<attribute id="0" title="relationship" type="string"/>
</attributes>
<nodes>
<node id="0" label="Gephi">
<attvalues>
<attvalue for="0" value="http://gephi.org"/>
<attvalue for="1" value="1"/>
</attvalues>
</node>
...
</nodes>
<edges>
<edge id="0" source="0" target="1">
<attvalues>
<attvalue for="0" value="link"/>
</attvalues>
</edge>
...
</edges>
</graph>
</gexf>
"""
gexf_template = """<?xml version="1.0" encoding="UTF-8"?>
<gexf>
<meta lastmodifieddate="2021-09-20">
<creator>Deep Discovery, Inc.</creator>
<description></description>
</meta>
<attributes class="node"></attributes>
<attributes class="edge"></attributes>
<graph mode="static" defaultedgetype="directed">
<nodes></nodes>
<edges></edges>
</graph>
</gexf>"""
@typechecked
def __init__(
self,
nodes_df: typing.Union[pyspark.sql.DataFrame, pd.DataFrame],
edges_df: typing.Union[pyspark.sql.DataFrame, pd.DataFrame],
description: str = "Network Visualization",
region: str = "us-west-2",
) -> None:
"""__init__ Create an object that can parse a pair of node/edge [pandas or pyspark] DataFrames and write valid GEXF XML
Parameters
----------
        nodes_df : pyspark.sql.DataFrame or pd.DataFrame
            A list of nodes from a GraphFrames motif
        edges_df : pyspark.sql.DataFrame or pd.DataFrame
            A list of edges from a GraphFrames motif
description : str, optional
The GEXF file description or the name of your visualization, by default "Network Visualization"
region: str, optional
The S3 region to save the file to via S3, by default "us-west-2"
"""
self.description = description
if isinstance(nodes_df, pyspark.sql.DataFrame):
self.local_nodes = nodes_df.rdd.map(lambda x: x.asDict()).collect()
elif isinstance(nodes_df, pd.DataFrame):
self.local_nodes = nodes_df.to_dict("records")
else:
raise DataFrameFormatException("Unfamiliar nodes DataFrame encountered!")
if isinstance(edges_df, pyspark.sql.DataFrame):
self.local_edges = edges_df.rdd.map(lambda x: x.asDict()).collect()
elif isinstance(edges_df, pd.DataFrame):
self.local_edges = edges_df.to_dict("records")
else:
raise DataFrameFormatException("Unfamiliar edges DataFrame encountered!")
# Validate the nodes and raise a MissingNodePropertyException if any are missing required fields
self.validate_nodes()
# Validate the edges and raise a MissingEdgePropertyException if any are missing required fields
self.validate_edges()
# Set the S3 region
self.region = region
# Define empty mock S3 connection for optional testing
self.mock_s3: typing.Any = None
@classmethod
def string_type(cls, python_type: type) -> str:
"""Convert a Python type to its corresponding GEXF type string"""
attr_str_type = None
if python_type == str:
attr_str_type = "string"
elif python_type == int:
attr_str_type = "integer"
elif python_type == float:
attr_str_type = "float"
else:
raise InvalidAttributeTypeException(
"Invalid type! Ony supports 'string', 'integer' and 'float'"
)
return attr_str_type
def validate_nodes(self):
"""validate_nodes Verifies that all nodes have 'id' and 'label' fields
Raises
------
MissingNodePropertyException
Raised if a node is missing the 'id' or 'label' fields
"""
        for node in self.local_nodes:
            # Explicit checks rather than asserts, which are stripped under python -O
            if "id" not in node:
                raise MissingNodePropertyException("A node is missing the 'id' field!")
            if "label" not in node:
                raise MissingNodePropertyException(
                    "A node is missing the 'label' field!"
                )
def validate_edges(self):
"""validate_edges Verifies that all edges have 'source' and 'target' fields
Raises
------
MissingEdgePropertyException
Raised if an edge is missing the 'source' or 'target' fields
"""
        for edge in self.local_edges:
            if "source" not in edge:
                raise MissingEdgePropertyException(
                    "An edge is missing the 'source' field!"
                )
            if "target" not in edge:
                raise MissingEdgePropertyException(
                    "An edge is missing the 'target' field!"
                )
def set_attributes(self, attrs_type: str) -> None:
"""set_attributes Set the node/edge attributes in the graph schema. Assumes first record is representative.
Parameters
----------
        attrs_type : str
            Either "node" or "edge"
"""
# Get the entity attributes tag so we can add any node attributes to the graph's schema
# fixme: here and below, self.tree.find can and will return None on inputs
# where these nodes are not present. It would be safer to explicitly check for None
# and provide nice exception than just wait for it to crash. Also mypy would be happier :)
        attrs_tag = self.tree.find(f".//attributes[@class='{attrs_type}']")
        # Create an entity attribute for each property of the first entity
attrs = {
k: {"value": v, "type": type(v)}
for k, v in getattr(self, f"local_{attrs_type}s")[0].items()
}
# Set self.node_attrs or self.edge_attrs to the attributes we pulled from the first element
setattr(self, f"{attrs_type}_attrs", attrs)
# Don't increment the property id for standard fields, so we start at 0
i_offset = 0
        # Numbers may have arrived as strings, since XML attribute values can't be numeric; detect and convert them below
for i, (attr_name, attr_dict) in enumerate(sorted(attrs.items())):
# Skip the id field, it doesn't need an <attribute>
if attrs_type == "node" and attr_name in ["id", "label"]:
i_offset += 1
continue
elif attrs_type == "edge" and attr_name in ["source", "target"]:
i_offset += 1
continue
attr_value = attr_dict["value"]
attr_type = attr_dict["type"]
# If it is a number then convert the type to a number
if attr_type == str and attr_value.replace(".", "", 1).isdigit():
# It is either an int or a float
if attr_value.isdigit():
attr_value = int(attr_value)
attr_type = int
else:
attr_value = float(attr_value)
attr_type = float
        # Convert the Python type to its GEXF name, e.g. str to 'string'
attr_str_type = GEXFWriter.string_type(attr_type)
# Subtract the property id offset to account for standard fields we skipped
property_id = str(i - i_offset)
# Add one of these: <attribute id="0" title="type" type="string"/>
ET.SubElement(
# type ignore here and below, because I'm not really sure if this can be None or not
attrs_tag, # type: ignore
"attribute",
{"id": property_id, "title": attr_name, "type": attr_str_type},
)
def add_elements(self) -> None:
"""add_elements add the nodes to the XML document"""
# Get the nodes tag so we can append <node> to it
nodes_tag = self.tree.find(".//nodes")
        # Stringify ids so the membership check below matches the stringified edge endpoints
        node_ids = [str(x["id"]) for x in self.local_nodes]
# Add the nodes to the document with their properties
for i, node in enumerate(self.local_nodes):
# Append to <nodes>, using capitalized Label instead of label for Gephi to display them right
e = ET.SubElement(
nodes_tag, "node", {"id": node["id"], "Label": node["label"]} # type: ignore
)
# Assign any attributes if there are more than two keys: id and label
if len(list(node.keys())) > 2:
attvalues = ET.SubElement(e, "attvalues")
# Don't increment the property id for standard fields, so we start at 0
node_i_offset = 0
self.node_attrs: typing.Dict # set in set_attributes method
                for j, (attr_name, attr_dict) in enumerate(
sorted(self.node_attrs.items())
):
# Skip the id/label fields, they are in the <node> tag and don't need an <attribute>
if attr_name in ["id", "label"]:
node_i_offset += 1
continue
# Subtract the property id offset to account for standard fields we skipped
                    node_property_id = str(j - node_i_offset)
# Add the value for the node's attribute
ET.SubElement(
attvalues,
"attvalue",
# All properties must be strings, the graph attribute will convert them on read
{"for": node_property_id, "value": str(node[attr_name])},
)
        # Get the edges tag so we can append <edge> to it
edges_tag = self.tree.find(".//edges")
# Add the edges to the document
for i, edge in enumerate(self.local_edges):
            # Stringify everything - ElementTree attribute values must be strings
            edge_dict = {
                "id": str(i),
                "source": str(edge["source"]),
                "target": str(edge["target"]),
            }
# Don't let any edges through without corresponding nodes - they cause headaches with missing attributes
if (
edge_dict["source"] not in node_ids
or edge_dict["target"] not in node_ids
):
raise MissingNodeException(
f"Edge from {edge_dict['source']} to {edge_dict['target']} has no corresponding node(s)!"
)
if "label" in edge:
edge_dict["Label"] = edge["label"]
# Append to <edges>
e = ET.SubElement(
edges_tag, # type: ignore
"edge",
edge_dict,
)
if len(list(edge.keys())) > 2:
attvalues = ET.SubElement(e, "attvalues")
# Don't increment the property id for standard fields, so we start at 0
edge_i_offset = 0
self.edge_attrs: typing.Dict # set in set_attributes method
                for j, (attr_name, attr_dict) in enumerate(
sorted(self.edge_attrs.items())
):
# Skip the source/target fields, they are in the <edge> tag and don't need an <attribute>
if attr_name in ["source", "target"]:
edge_i_offset += 1
continue
# Subtract the property id offset to account for standard fields we skipped
                    edge_property_id = str(j - edge_i_offset)
# Add the value for the edge's attribute
ET.SubElement(
attvalues,
"attvalue",
# All properties must be strings, the graph attribute will convert them on read
{"for": edge_property_id, "value": str(edge[attr_name])},
)
def parse(self):
"""parse parse the pyspark.sql.DataFrame for nodes/edges into an xml.etree.ElementTree, which can write XML"""
        # Get an ElementTree and the root element
self.tree = ET.ElementTree(ET.fromstring(self.gexf_template))
self.root = self.tree.getroot()
# Get the description tag and set the description text
description_tag = self.tree.find(".//description")
description_tag.text = self.description # type: ignore
# Set the graph schema attributes for nodes and edges
self.set_attributes("node")
self.set_attributes("edge")
self.add_elements()
@classmethod
def parse_s3_url(cls, s3_url: str) -> typing.Dict[str, str]:
"""parse_s3_url parses S3 url into bucket and path
Parameters
----------
s3_url : str
A full S3 url: s3://<bucket_name>/<path>
Returns
-------
typing.Dict[str, str]
A dict of the format:
{
"bucket": <bucket_name>,
"path": <path>
}
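
        Example (bucket and path are illustrative):
            >>> GEXFWriter.parse_s3_url("s3://my-bucket/graphs/network.gexf")
            {'bucket_name': 'my-bucket', 'path': 'graphs/network.gexf'}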
"""
# Parse the url
url_parts = urlparse(s3_url)
        # Validate the url scheme
        if url_parts.scheme != "s3":
            raise InvalidS3URI(
                f"Invalid scheme is not s3: {url_parts.scheme}. Did you mean to use s3=False to write to local disk?"
            )
        # Validate the bucket name
        if not url_parts.netloc:
            raise InvalidS3URI("Empty bucket name!")
        # Validate the path
        if not url_parts.path:
            raise InvalidS3URI("Empty path name!")
# Strip the / from the path
path = url_parts.path
if path.startswith("/"):
path = path[1:]
return {"bucket_name": url_parts.netloc, "path": path}
def export(self) -> str:
"""export Generate a valid GEXF XML string from the parsed node/edge DataFrames"""
# First convert our etree XML document to a valid XML/GEXF string...
xml_bytes = ET.tostring(self.root)
# Then add the schema information last, because etree makes things very hard if we have it earlier
xml_string = xml_bytes.decode().replace(
"<gexf>",
'<gexf xmlns="http://www.gexf.net/1.2draft" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.gexf.net/1.2draft http://www.gexf.net/1.2draft/gexf.xsd" version="1.2">',
)
# We have to beautify using BeautifulSoup as this is difficult in etree
clean_xml_string = BeautifulSoup(xml_string, "xml").prettify()
# Let's explicitly manage memory as this thing is often huge
del xml_bytes, xml_string
return clean_xml_string
def write_to_s3(self, s3_uri: str, mock: bool = False) -> None:
"""write_to_s3 Write a GEXF file of the node/edge list we parsed to S3
Parameters
----------
        s3_uri : str
S3 URI to store GEXF XML export
mock : bool, optional
whether to mock boto3 for local S3 testing without using AWS, by default False
"""
# Parse and validate the S3 url
s3_url_parts = GEXFWriter.parse_s3_url(s3_uri)
# Boto3 writes bytes, not strings
clean_xml_bytes = self.export().encode()
if mock:
self.mock_s3 = mock_s3()
self.mock_s3.start()
# Set this up here so we can mock it later if we want to
s3_client = boto3.client("s3", region_name=self.region)
if mock:
# Mock up the S3 bucket if we are testing locally
location = typing.cast(
CreateBucketConfigurationTypeDef, {"LocationConstraint": self.region}
)
s3_client.create_bucket(
Bucket=s3_url_parts["bucket_name"], CreateBucketConfiguration=location
)
# Write the file to S3
s3_client.put_object(
Body=clean_xml_bytes,
Bucket=s3_url_parts["bucket_name"],
Key=s3_url_parts["path"],
)
if mock:
self.mock_s3.stop()
def write_to_disk(self, path: str) -> None:
"""write_to_disk Write a GEXF file of the node/edge list we parsed to local disk
Parameters
----------
path : str
Local path to store GEXF XML export
"""
# Write to local disk - not possible when imported into a Databricks Notebook
with open(path, "w") as f:
clean_xml_string = self.export()
f.write(clean_xml_string)
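
A minimal usage sketch with pandas DataFrames. The node/edge values, extra columns and output paths below are illustrative; the only requirements from the class are that nodes carry 'id' and 'label' and edges carry 'source' and 'target'. Calling write_to_s3 with mock=True exercises the moto-backed S3 path without AWS credentials.

import pandas as pd

# Illustrative node/edge data - any DataFrames work as long as nodes
# have 'id' and 'label' fields and edges have 'source' and 'target'
nodes_df = pd.DataFrame(
    [
        {"id": "0", "label": "Gephi", "url": "http://gephi.org"},
        {"id": "1", "label": "GraphFrames", "url": "http://graphframes.github.io"},
    ]
)
edges_df = pd.DataFrame([{"source": "0", "target": "1", "relationship": "link"}])

writer = GEXFWriter(nodes_df, edges_df, description="A Web network")
writer.parse()  # build the ElementTree from the DataFrames
writer.write_to_disk("network.gexf")

# Or test the S3 path locally without AWS credentials:
writer.write_to_s3("s3://my-bucket/network.gexf", mock=True)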
Copyright <YEAR> <COPYRIGHT HOLDER>
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.