Ver Fonte

Merge remote-tracking branch 'origin/master'

xufei há 3 meses atrás
pai
commit
19e7820225

+ 54 - 0
src/main/java/com/winhc/max/compute/graph/util/GeneralAggregator.java

@@ -0,0 +1,54 @@
+package com.winhc.max.compute.graph.util;
+
+import com.aliyun.odps.graph.Aggregator;
+import com.aliyun.odps.graph.WorkerContext;
+import com.aliyun.odps.io.BooleanWritable;
+import com.aliyun.odps.io.NullWritable;
+import com.aliyun.odps.io.Writable;
+
+import java.io.IOException;
+
+/**
+ * @Author: XuJiakai
+ * 2023/5/16 10:30
+ */
+public class GeneralAggregator extends Aggregator<GeneralAggregatorValue> {
+
+    @Override
+    public GeneralAggregatorValue createInitialValue(WorkerContext context) throws IOException {
+        return new GeneralAggregatorValue();
+    }
+
+    @Override
+    public void aggregate(GeneralAggregatorValue generalAggregatorValue, Object item) throws IOException {
+        if (item instanceof NullWritable) {
+            generalAggregatorValue.increment();
+        } else {
+            BooleanWritable tmp = ((BooleanWritable) item);
+            generalAggregatorValue.update(tmp.get());
+        }
+    }
+
+    @Override
+    public void merge(GeneralAggregatorValue generalAggregatorValue, GeneralAggregatorValue partial) throws IOException {
+        boolean tmp = partial.getFlag().get();
+        generalAggregatorValue.update(tmp);
+        generalAggregatorValue.increment(partial.getCount());
+    }
+
+    @Override
+    public boolean terminate(WorkerContext context, GeneralAggregatorValue generalAggregatorValue) throws IOException {
+        System.out.println("step: " + context.getSuperstep() + ", iterations: " + generalAggregatorValue.getCount() + " ,flag=" + generalAggregatorValue.getFlag().get());
+        Writable lastAggregatedValue = context.getLastAggregatedValue();
+        if (lastAggregatedValue != null) {
+            GeneralAggregatorValue lastAgg = ((GeneralAggregatorValue) lastAggregatedValue);
+            long count = lastAgg.getCount();
+            if (count > 10) {
+                System.out.println("\tmanual operation...");
+                return false;
+            }
+        }
+        return !generalAggregatorValue.getFlag().get();
+    }
+
+}

+ 78 - 0
src/main/java/com/winhc/max/compute/graph/util/GeneralAggregatorValue.java

@@ -0,0 +1,78 @@
+package com.winhc.max.compute.graph.util;
+
+import com.aliyun.odps.io.BooleanWritable;
+import com.aliyun.odps.io.LongWritable;
+import com.aliyun.odps.io.Writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * @Author: XuJiakai
+ * 2023/5/16 10:31
+ */
+public class GeneralAggregatorValue implements Writable {
+
+    /**
+     * 为false时表示迭代完毕,可以结束
+     */
+    private BooleanWritable flag;
+
+    /**
+     * 本轮迭代顶点数
+     */
+    private LongWritable num;
+
+    public BooleanWritable getFlag() {
+        return flag;
+    }
+
+    public long getCount() {
+        return this.num.get();
+    }
+
+    public void increment() {
+        increment(1);
+    }
+
+    public void increment(long num) {
+        this.num.set(this.num.get() + num);
+    }
+
+    public void update(boolean flag) {
+        this.flag.set(this.flag.get() || flag);
+    }
+
+
+    public GeneralAggregatorValue() {
+        flag = new BooleanWritable(false);
+        this.num = new LongWritable(0);
+    }
+
+    @Override
+    public String toString() {
+        return "EnterpriseGroupAggValue{" +
+                "flag=" + flag +
+                ", num=" + num +
+                '}';
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+        flag.write(out);
+        num.write(out);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+        if (flag == null) {
+            flag = new BooleanWritable();
+        }
+        if (num == null) {
+            num = new LongWritable();
+        }
+        flag.readFields(in);
+        num.readFields(in);
+    }
+}