55from graphblas_algorithms.classes.graph import to_undirected_graph
66from graphblas_algorithms.utils import not_implemented_for
77
8+ from ._utils import normalize_chunksize, partition
9+
810__all__ = [
911 "triangles",
1012 "transitivity",
@@ -90,11 +92,11 @@ def _split(L, k):
9092
9193
9294# TODO: should this move into algorithms?
93- def _square_clustering_split(G, node_ids=None, *, nsplits ):
95+ def _square_clustering_split(G, node_ids=None, *, chunksize ):
9496 if node_ids is None:
9597 node_ids, _ = G._A.reduce_rowwise(monoid.any).to_coo(values=False)
9698 result = None
97- for chunk_ids in _split(node_ids, nsplits ):
99+ for chunk_ids in partition(chunksize, node_ids ):
98100 res = algorithms.square_clustering(G, chunk_ids)
99101 if result is None:
100102 result = res
@@ -103,36 +105,32 @@ def _square_clustering_split(G, node_ids=None, *, nsplits):
103105 return result
104106
105107
106- def square_clustering(G, nodes=None, *, nsplits="auto "):
107- # `nsplits ` is used to split the computation into chunks.
108+ def square_clustering(G, nodes=None, *, chunksize="256 MiB "):
109+ # `chunksize ` is used to split the computation into chunks.
108110 # square_clustering computes `A @ A`, which can get very large, even dense.
109- # The default `nsplits ` is to choose the number so that `Asubset @ A`
111+ # The default `chunksize ` is to choose the number so that `Asubset @ A`
110112 # will be about 256 MB if dense.
111113 G = to_undirected_graph(G)
112114 if len(G) == 0:
113115 return {}
114- if nsplits == "auto":
115- # TODO: make a utility function for this that can be reused
116- # Also, should we use `chunksize` instead of `nsplits`?
117- targetsize = 256 * 1024 * 1024 # 256 MB
118- nsplits = len(G) ** 2 * G._A.dtype.np_type.itemsize // targetsize
119- if nsplits <= 1:
120- nsplits = None
116+
117+ chunksize = normalize_chunksize(chunksize, len(G) * G._A.dtype.np_type.itemsize, len(G))
118+
121119 if nodes is None:
122120 # Should we use this one for subsets of nodes as well?
123- if nsplits is None:
121+ if chunksize is None:
124122 result = algorithms.square_clustering(G)
125123 else:
126- result = _square_clustering_split(G, nsplits=nsplits )
124+ result = _square_clustering_split(G, chunksize=chunksize )
127125 return G.vector_to_nodemap(result, fill_value=0)
128126 if nodes in G:
129127 idx = G._key_to_id[nodes]
130128 return algorithms.single_square_clustering(G, idx)
131129 ids = G.list_to_ids(nodes)
132- if nsplits is None:
130+ if chunksize is None:
133131 result = algorithms.square_clustering(G, ids)
134132 else:
135- result = _square_clustering_split(G, ids, nsplits=nsplits )
133+ result = _square_clustering_split(G, ids, chunksize=chunksize )
136134 return G.vector_to_nodemap(result)
137135
138136
0 commit comments