|
@@ -0,0 +1,54 @@
|
|
|
+package com.winhc.max.compute.graph.util;
|
|
|
+
|
|
|
+import com.aliyun.odps.graph.Aggregator;
|
|
|
+import com.aliyun.odps.graph.WorkerContext;
|
|
|
+import com.aliyun.odps.io.BooleanWritable;
|
|
|
+import com.aliyun.odps.io.NullWritable;
|
|
|
+import com.aliyun.odps.io.Writable;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+
|
|
|
+/**
|
|
|
+ * @Author: XuJiakai
|
|
|
+ * 2023/5/16 10:30
|
|
|
+ */
|
|
|
+public class GeneralAggregator extends Aggregator<GeneralAggregatorValue> {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public GeneralAggregatorValue createInitialValue(WorkerContext context) throws IOException {
|
|
|
+ return new GeneralAggregatorValue();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void aggregate(GeneralAggregatorValue generalAggregatorValue, Object item) throws IOException {
|
|
|
+ if (item instanceof NullWritable) {
|
|
|
+ generalAggregatorValue.increment();
|
|
|
+ } else {
|
|
|
+ BooleanWritable tmp = ((BooleanWritable) item);
|
|
|
+ generalAggregatorValue.update(tmp.get());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void merge(GeneralAggregatorValue generalAggregatorValue, GeneralAggregatorValue partial) throws IOException {
|
|
|
+ boolean tmp = partial.getFlag().get();
|
|
|
+ generalAggregatorValue.update(tmp);
|
|
|
+ generalAggregatorValue.increment(partial.getCount());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean terminate(WorkerContext context, GeneralAggregatorValue generalAggregatorValue) throws IOException {
|
|
|
+ System.out.println("step: " + context.getSuperstep() + ", iterations: " + generalAggregatorValue.getCount() + " ,flag=" + generalAggregatorValue.getFlag().get());
|
|
|
+ Writable lastAggregatedValue = context.getLastAggregatedValue();
|
|
|
+ if (lastAggregatedValue != null) {
|
|
|
+ GeneralAggregatorValue lastAgg = ((GeneralAggregatorValue) lastAggregatedValue);
|
|
|
+ long count = lastAgg.getCount();
|
|
|
+ if (count > 10) {
|
|
|
+ System.out.println("\tmanual operation...");
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return !generalAggregatorValue.getFlag().get();
|
|
|
+ }
|
|
|
+
|
|
|
+}
|