From aa2ea0452bb8bc63f53d6c1e548d2ae8176e72b4 Mon Sep 17 00:00:00 2001
From: Wey Gu
Date: Wed, 1 Mar 2023 01:45:18 +0000
Subject: [PATCH] feat: get all algo function

---
 docs/API.md         |  1 +
 ngdi/nebula_algo.py | 61 ++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/docs/API.md b/docs/API.md
index f93c9de..b361b61 100644
--- a/docs/API.md
+++ b/docs/API.md
@@ -53,6 +53,7 @@ ngdi.`NebulaDataFrameObject` is a Spark DataFrame or Pandas DataFrame, which can
 
 ### Functions
 
+- `ngdi.NebulaGraphObject.algo.get_all_algo()` returns all algorithms supported by the engine.
 - `ngdi.NebulaGraphObject.algo.pagerank()` runs the PageRank algorithm on the NetworkX Graph. not yet implemented.
 
 ## NebulaAlgorithm

diff --git a/ngdi/nebula_algo.py b/ngdi/nebula_algo.py
index d642988..b742a13 100644
--- a/ngdi/nebula_algo.py
+++ b/ngdi/nebula_algo.py
@@ -7,6 +7,11 @@
 from ngdi.nebula_data import NebulaDataFrameObject as NebulaDataFrameObjectImpl
 
 
+def algo(func):
+    func.is_algo = True
+    return func
+
+
 class NebulaAlgorithm:
     def __init__(self, obj: NebulaGraphObjectImpl or NebulaDataFrameObjectImpl):
         if isinstance(obj, NebulaGraphObjectImpl):
@@ -33,6 +38,17 @@ class NebulaDataFrameAlgorithm:
     def __init__(self, ndf_obj: NebulaDataFrameObjectImpl):
         self.ndf_obj = ndf_obj
+        self.algorithms = []
+
+    def register_algo(self, func):
+        self.algorithms.append(func.__name__)
+
+    def get_all_algo(self):
+        if not self.algorithms:
+            for name, func in NebulaDataFrameAlgorithm.__dict__.items():
+                if hasattr(func, "is_algo"):
+                    self.register_algo(func)
+        return self.algorithms
 
     def check_engine(self):
         """
@@ -73,6 +89,7 @@ def get_spark_dataframe(self):
         )
         return df
 
+    @algo
     def pagerank(
         self, reset_prob: float = 0.15, max_iter: int = 10, weighted: bool = False
     ):
@@ -85,6 +102,7 @@ def pagerank(
 
         return result
 
+    @algo
     def connected_components(self, max_iter: int = 10, weighted: bool = False):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "CcConfig", "ConnectedComponentsAlgo"
@@ -97,6 +115,7 @@ def connected_components(self, max_iter: int = 10, weighted: bool = False):
 
         return result
 
+    @algo
     def label_propagation(self, max_iter: int = 10, weighted: bool = False):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "LPAConfig", "LabelPropagationAlgo"
@@ -110,6 +129,7 @@ def label_propagation(self, max_iter: int = 10, weighted: bool = False):
 
         return result
 
+    @algo
     def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.0001):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "LouvainConfig", "LouvainAlgo"
@@ -121,6 +141,7 @@ def louvain(self, max_iter: int = 10, internalIter: int = 10, tol: float = 0.000
 
         return result
 
+    @algo
     def k_core(self, max_iter: int = 10, degree: int = 2):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "KCoreConfig", "KCoreAlgo"
@@ -145,6 +166,7 @@ def k_core(self, max_iter: int = 10, degree: int = 2):
 
     #     return result
 
+    @algo
     def degree_statics(self):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "DegreeStaticConfig", "DegreeStaticAlgo"
@@ -156,6 +178,7 @@ def degree_statics(self):
 
         return result
 
+    @algo
     def betweenness_centrality(
         self, max_iter: int = 10, degree: int = 2, weighted: bool = False
     ):
@@ -171,6 +194,7 @@ def betweenness_centrality(
 
         return result
 
+    @algo
     def coefficient_centrality(self, type: str = "local"):
         # type could be either "local" or "global"
         assert type.lower() in ["local", "global"], (
@@ -187,6 +211,7 @@ def coefficient_centrality(self, type: str = "local"):
 
         return result
 
+    @algo
     def bfs(self, max_depth: int = 10, root: int = 1):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "BfsConfig", "BfsAlgo"
@@ -199,18 +224,19 @@ def bfs(self, max_depth: int = 10, root: int = 1):
         return result
 
     # dfs is not yet supported, need to revisit upstream nebula-algorithm
-    #
-    # def dfs(self, max_depth: int = 10, root: int = 1):
-    #     engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
-    #         "DfsConfig", "DfsAlgo"
-    #     )
-    #     df = self.get_spark_dataframe()
+    @algo
+    def dfs(self, max_depth: int = 10, root: int = 1):
+        engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
+            "DfsConfig", "DfsAlgo"
+        )
+        df = self.get_spark_dataframe()
 
-    #     config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
-    #     result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)
+        config = spark._jvm.DfsConfig(max_depth, root, encode_vertex_id)
+        result = spark._jvm.DfsAlgo.apply(jspark, df._jdf, config)
 
-    #     return result
+        return result
 
+    @algo
     def hanp(
         self,
         hop_attenuation: float = 0.5,
@@ -233,6 +259,7 @@ def hanp(
 
         return result
 
+    # @algo
     # def node2vec(
     #     self,
     #     max_iter: int = 10,
@@ -277,6 +304,7 @@ def hanp(
 
     #     return result
 
+    @algo
     def jaccard(self, tol: float = 1.0):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "JaccardConfig", "JaccardAlgo"
@@ -288,6 +316,7 @@ def jaccard(self, tol: float = 1.0):
 
         return result
 
+    @algo
     def strong_connected_components(self, max_iter: int = 10, weighted: bool = False):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "CcConfig", "StronglyConnectedComponentsAlgo"
@@ -300,6 +329,7 @@ def strong_connected_components(self, max_iter: int = 10, weighted: bool = False
 
         return result
 
+    @algo
     def triangle_count(self):
         engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
             "TriangleConfig", "TriangleCountAlgo"
@@ -310,6 +340,7 @@ def triangle_count(self):
 
         return result
 
+    # @algo
     # def closeness(self, weighted: bool = False):
     #     # TBD: ClosenessAlgo is not yet encodeID compatible
     #     engine, spark, jspark, encode_vertex_id = self.get_spark_engine_context(
@@ -329,6 +360,17 @@ class NebulaGraphAlgorithm:
     def __init__(self, graph):
         self.graph = graph
+        self.algorithms = []
+
+    def register_algo(self, func):
+        self.algorithms.append(func.__name__)
+
+    def get_all_algo(self):
+        if not self.algorithms:
+            for name, func in NebulaGraphAlgorithm.__dict__.items():
+                if hasattr(func, "is_algo"):
+                    self.register_algo(func)
+        return self.algorithms
 
     def check_engine(self):
         """
@@ -343,6 +385,7 @@ def check_engine(self):
             "For example: df = nebula_graph.to_df; df.algo.pagerank()",
         )
 
+    @algo
     def pagerank(self, reset_prob=0.15, max_iter=10):
         self.check_engine()
         pass
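
For reference, the listing mechanism this patch introduces can be exercised in isolation: the `algo` decorator tags a method with an `is_algo` attribute, and `get_all_algo()` lazily scans the class `__dict__` for tagged functions, caching their names in `self.algorithms`. Below is a minimal standalone sketch of that pattern; the `DemoAlgorithm` class and its two methods are hypothetical stand-ins, not part of ngdi.

    # Standalone sketch of the @algo registration pattern used in this patch.
    # DemoAlgorithm and its methods are illustrative, not ngdi API.
    def algo(func):
        func.is_algo = True  # tag the function so the class scan can find it
        return func

    class DemoAlgorithm:
        def __init__(self):
            self.algorithms = []

        def register_algo(self, func):
            self.algorithms.append(func.__name__)

        def get_all_algo(self):
            # scan only on first call; names are cached in self.algorithms
            if not self.algorithms:
                for name, func in DemoAlgorithm.__dict__.items():
                    if hasattr(func, "is_algo"):
                        self.register_algo(func)
            return self.algorithms

        @algo
        def pagerank(self):
            pass

        @algo
        def bfs(self):
            pass

    print(DemoAlgorithm().get_all_algo())  # ['pagerank', 'bfs']

One consequence of scanning the class `__dict__` rather than `dir()` is that algorithms inherited from a base class would not be listed; that is consistent with the flat class layout shown in this patch.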