原创

Java-同步异步多线程专题-分组有序执行并等待结果-CompletableFuture-thenRunAsync-版本1

说明

适用于批量数据中,需要按指定字段分组,每个分组内保持原数据顺序消费。由于每个小组内需要按顺序执行,小组内只能使用单线程。整体消费的速度取决于分组数量,分组数量对应线程数,这里设置最大线程数16。线程和分组并不是一一绑定,同一个线程处理完一个分组后,会去处理下一个分组。

1. Example.java

package cn.jiangjiesheng.groupedSequentialExecutor;

import cn.jiangjiesheng.core.utils.MathUtils;
import cn.jiangjiesheng.core.utils.ZipUtil;
import lombok.Data;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/**
 * Java-同步异步多线程专题-分组有序执行并等待结果-CompletableFuture-thenRunAsync
 *
 * 适用于批量数据中,需要按指定字段分组,每个分组内保持原数据顺序消费。
 * 由于每个小组内需要按顺序执行,小组内只能使用单线程。整体消费的速度取决于分组数量,分组数量对应线程数,这里设置最大线程数16。
 * 线程和分组并不是一一绑定,同一个线程处理完一个分组后,会去处理下一个分组。
 *
 * 【同步多线程+分组+保持顺序+等待结果(分组+有序,所以线程数取决于分组数)】
 * 【同步】按任务列表中的某个字段分组,每个分组又要顺序执行,并最终等待所有线程拿到处理任务的结果。
 *
 * newKeyAffinityExecutor:【异步多线程】+分组+保持顺序+【不等待结果】
 *
 */
public class Example {

    public static void main(String[] args) {
        // 模拟数据
        List<TaskItem> taskItems = new ArrayList<>();
        for (int i = 0; i < 1500; i++) {
            taskItems.add(new TaskItem(i, "GROUP_ID_REAL-" + (i % 8)));
        }

        // 如果需要 复制15遍,生成1500条
        // List<TaskItem> result = new ArrayList<>(Collections.nCopies(15, taskItems))
        //         .stream()
        //         .flatMap(List::stream)
        //         .collect(Collectors.toList());
        // 如果需要
        // result = result.stream()
        //         .sorted(Comparator.comparing(TaskItem::getGroupId)
        //         .thenComparing(TaskItem::getId))
        //         .collect(Collectors.toList());

        long startTime = System.currentTimeMillis();
        // 错误收集
        List<String> errorList = Collections.synchronizedList(new ArrayList<>());
        // 在外部(任务级别)定义
        AtomicInteger logCount = new AtomicInteger(0);

        new GroupedSequentialExecutor<>(new GroupedSequentialExecutorListener<TaskItem>() {

            @Override
            public boolean isDebugMode() {
                return true;
            }
            @Override
            public List<TaskItem> getTaskItemList() {
                return taskItems;
            }
            @Override
            public Function<TaskItem, ?> getTaskGroupKeyExtractor() {
                return item -> item.groupId;
                //或 return TaskItem::getGroupId;
            }
            @Override
            public void executeItem(TaskItem item) throws Exception {
                mockTask(item.id);
            }
            @Override
            public void handleProgress(String groupIdAlias, Object groupId, TaskItem item, String threadName, Integer totalCount, Integer totalCompleteCount) {
                //可以进度redis,为了避免频发写入,可以控制共写10次(单实例可以存内存,并且设置有效期)
                //因为多线程线程之间无序,所有打印的最后一条完成数不一定等于总数,应该往前几条
                if (GroupedSequentialExecutor.shouldPrint(totalCount, totalCompleteCount, 13, logCount)) {
                    System.err.printf(
                            "当前处理进度(不一定是成功),groupIdAlias: %s, groupId: %s, id: %s, 线程: %s, 总数: %d, 完成数: %d, 完成率: %s, 耗时: %s%n",
                            groupIdAlias,
                            groupId,
                            item.id,
                            threadName,
                            totalCount,
                            totalCompleteCount,
                            MathUtils.calcPercentStr(totalCompleteCount, totalCount, 2),
                            ZipUtil.getCostTimeString(startTime)
                    );
                }
            }
            @Override
            public void handleError(String groupIdAlias, Object groupId, TaskItem item, Exception ex) {
                synchronized (errorList) {
                    errorList.add("处理失败: groupIdAlias:" + groupIdAlias + ",groupId:" + groupId + ",当前对象:" + item + ", 原因: " + ex.getMessage());
                }
            }
            @Override
            public void handleTaskFinishStats(Integer totalCount, Integer totalCompleteCount, Integer taskGroupCount, Integer nThreadCount, String costTime, String extra) {
                System.err.printf(
                        "任务处理完成,totalCount: %s, totalCompleteCount: %s, taskGroupCount: %s, nThreadCount: %s, costTime: %s, extra: %s%n",
                        totalCount,
                        totalCompleteCount,
                        taskGroupCount,
                        nThreadCount,
                        costTime,
                        extra
                );
            }
        }).execute();

        // 输出结果
        if (errorList.isEmpty()) {
            System.out.println("所有任务成功完成!");
        } else {
            System.out.println("所有任务成功完成,但存在错误:");
            errorList.forEach(System.out::println);
        }

        System.out.println("任务耗时:" + ZipUtil.getCostTimeString(startTime));
    }


    @Data
    static class TaskItem {
        Integer id;
        String groupId;

        public TaskItem(Integer id, String groupId) {
            this.id = id;
            this.groupId = groupId;
        }

        @Override
        public String toString() {
            return "TaskItem{id=" + id + ", groupId='" + groupId + "'}";
        }
    }

    private static void mockTask(Integer id) throws Exception {
        if (id % 8 == 0) {
            throw new Exception("模拟出错: id=" + id);
        }
        Thread.sleep(50);
        System.out.println("模拟处理成功: " + id + " [Thread: " + Thread.currentThread().getName() + "]");
    }

}

2. GroupedSequentialExecutorListener.java

package cn.jiangjiesheng.groupedSequentialExecutor;

import java.util.List;
import java.util.function.Function;

public interface GroupedSequentialExecutorListener<T> {

    boolean isDebugMode();

    /**
     * 必须实现
     * @return
     */
    List<T> getTaskItemList();
    /**
     * 必须实现
     * @return
     */
     Function<T, ?> getTaskGroupKeyExtractor();
    /**
     * 必须实现
     * 执行一个任务,可能抛出异常
     * 执行单个任务,必须是同步阻塞方法,事务提交后才返回。
     */
    void executeItem(T item) throws Exception;

    void handleProgress(String groupIdAlias, Object groupId, T item, String threadName, Integer totalCount, Integer totalCompleteCount);

    void handleError(String groupIdAlias, Object groupId, T item, Exception ex);

    void handleTaskFinishStats(Integer totalCount, Integer totalCompleteCount, Integer taskGroupCount, Integer nThreadCount, String costTime, String extra);

}

3. GroupedSequentialExecutor.java

package cn.jiangjiesheng.groupedSequentialExecutor;

import cn.jiangjiesheng.core.utils.ZipUtil;
import com.google.common.base.Preconditions;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
 * Java-同步异步多线程专题-分组有序执行并等待结果-CompletableFuture-thenRunAsync
 *
 * 适用于批量数据中,需要按指定字段分组,每个分组内保持原数据顺序消费。
 * 由于每个小组内需要按顺序执行,小组内只能使用单线程。整体消费的速度取决于分组数量,分组数量对应线程数,这里设置最大线程数16。
 * 线程和分组并不是一一绑定,同一个线程处理完一个分组后,会去处理下一个分组。
 *
 * 【同步多线程+分组+保持顺序+等待结果(分组+有序,所以线程数取决于分组数)】
 * 【同步】按任务列表中的某个字段分组,每个分组又要顺序执行,并最终等待所有线程拿到处理任务的结果。
 *
 * newKeyAffinityExecutor:【异步多线程】+分组+保持顺序+【不等待结果】
 *
 */
@Slf4j
public class GroupedSequentialExecutor<T> {
    //自定义线程池(可选)  线程数会影响处理时间
    //普通开发机(4核)    8 ~ 16    4 × 2 ~ 4 × 4
    //生产服务器(8核)    16 ~ 32    8 × 2 ~ 8 × 4
    //高并发优化    32 ~ 64    需测试,防止线程过多导致上下文切换开销
    //保守稳妥    CPU核心数 × 2    平衡性能和资源

    //实际使用还是取决于任务分组数的情况
    private final Integer MAX_THREAD_COUNT = 16;

    // 进度计数器
    private final AtomicInteger completedCount = new AtomicInteger(0);
    // 总任务数
    private int totalTasks = 0;
    private ExecutorService executorService;
    private final List<T> taskItemList;
    private final boolean debugMode;
    private final boolean shouldShutdownExecutor;

    private final GroupedSequentialExecutorListener<T> groupedSequentialExecutorListener;

    public GroupedSequentialExecutor(GroupedSequentialExecutorListener<T> groupedSequentialExecutorListener) {
        this.groupedSequentialExecutorListener = groupedSequentialExecutorListener;
        this.taskItemList = this.groupedSequentialExecutorListener.getTaskItemList();
        this.debugMode = this.groupedSequentialExecutorListener.isDebugMode();
        this.shouldShutdownExecutor = true;
        check();
    }

    public void execute() {
        long startTime = System.currentTimeMillis();
        int taskGroupCount = 0;
        int nThreads= 0;
        try {
            //结合处理进度的逻辑,业务调用前,就应该先判断,这里只是兜底。
            if (taskItemList.isEmpty()) {
                return;
            }

            // 记录总任务数(用于计算进度)
            totalTasks = taskItemList.size();
            completedCount.set(0);

            // 分组
            Map<Object, List<T>> grouped = taskItemList.stream()
                    .collect(Collectors.groupingBy(this.groupedSequentialExecutorListener.getTaskGroupKeyExtractor()));
            taskGroupCount = grouped.keySet().size();
             nThreads = Math.min(taskGroupCount, MAX_THREAD_COUNT);
            //因为分组还要按顺序
            this.executorService = Executors.newFixedThreadPool(nThreads, new ThreadFactory() {
                private final AtomicInteger threadNum = new AtomicInteger(1);

                @Override
                public Thread newThread(Runnable r) {
                    Thread t = new Thread(r);
                    t.setName("grouped-sequential-executor-" + threadNum.getAndIncrement());
                    return t;
                }
            });

            //分组数据,可以不打印,节省时间
            if (this.debugMode) {
                log.info("共有分组数: {},各分组任务数: {}",
                        taskGroupCount,
                        grouped.entrySet().stream()
                                .collect(Collectors.toMap(
                                        Map.Entry::getKey,
                                        e -> e.getValue().size()
                                )));
            }

            //这个一开始会利用到更多的线程
            //handleOne(grouped);

            //这个始终按输入数据的分组数设置线程的,设置了再多也不行
            handleTwo(grouped);

        } finally {
            //保持状态正确
            this.groupedSequentialExecutorListener.handleTaskFinishStats(totalTasks, completedCount.get(), taskGroupCount, nThreads, ZipUtil.getCostTimeString(startTime), "正常结束");

            totalTasks = 0;
            completedCount.set(0);

            // 如果使用了自定义线程池,不关闭;否则如果是内部创建的 commonPool,无需处理 【如果调用的不平凡,应该都可以关闭】
            // 注意:如果是传入的 customPool,这里主动关闭了,如果确实需要设置可以不关闭,那就加个构造器入参,一起判断修改shouldShutdownExecutor的值
            if (this.shouldShutdownExecutor && this.executorService != null) {
                //外层重复关闭也没关系,不报错且幂等
                this.executorService.shutdown();
                try {
                    if (!executorService.awaitTermination(30, TimeUnit.SECONDS)) {
                        executorService.shutdownNow();
                    }
                } catch (InterruptedException e) {
                    executorService.shutdownNow();
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /**
     * 控制最多打印次数
     *
     * @param totalCount
     * @param totalCompleteCount
     * @param maxTimes
     * @param logCount 计数器,放在调用外部
     * @return
     */
    public static boolean shouldPrint(Integer totalCount, Integer totalCompleteCount, Integer maxTimes, AtomicInteger logCount) {
        // 防护:totalCount <= 0
        if (totalCount == null || totalCount <= 0 || totalCompleteCount == null) {
            return false;
        }

        // 计算每多少条打印一次(向上取整,保证最多 maxTimes 次)
        // 向上取整
        int idealInterval = (totalCount + maxTimes - 1) / maxTimes;
        idealInterval = Math.max(1, idealInterval);

        // 判断是否应该触发打印
        boolean shouldLog =
                totalCompleteCount == 0 ||  // 0%
                        totalCompleteCount.equals(totalCount) ||  // 100%
                        (totalCompleteCount % idealInterval == 0);  // 中间等间隔

        // 关键:用计数器保证最多打印 maxTimes 次
        return shouldLog && logCount.incrementAndGet() <= maxTimes;
    }


    /**
     * 逻辑1 这个一开始会利用到更多的线程
     * @param grouped
     */
    private void handleOne(Map<Object, List<T>> grouped) {
        List<CompletableFuture<Void>> groupFutures = new ArrayList<>();
        final ConcurrentMap<Object, String> groupIdToAlias = new ConcurrentHashMap<>();
        final AtomicInteger aliasCounter = new AtomicInteger(1); // 从 1 开始
        for (Map.Entry<Object, List<T>> entry : grouped.entrySet()) {
            Object groupId = entry.getKey();      // 分组 ID
            List<T> groupItems = entry.getValue(); // 该组的所有任务
            String groupIdAlias = getGroupIdAlias(groupId, groupIdToAlias, aliasCounter);

            CompletableFuture<Void> chain = CompletableFuture.completedFuture(null);

            for (T item : groupItems) {
                final T currentItem = item;
                //串行(对于同一个任务处理组,虽然有多个线程,但是处理的任务还是按顺序进行的,这里会使用到都有executorService现场数)
                chain = chain.thenRunAsync(() -> {
                    try {
                        if (this.debugMode) {
                            log.info("当前任务处理组: {},线程: {},任务明细: {}",
                                    groupId,
                                    Thread.currentThread().getName(),
                                    currentItem
                            );
                        }
                        groupedSequentialExecutorListener.executeItem(currentItem);
                    } catch (Exception e) {
                        if (groupedSequentialExecutorListener != null) {
                            groupedSequentialExecutorListener.handleError(groupIdAlias, groupId, currentItem, e);
                        } else {
                            System.err.printf("当前任务处理出错: %s, reason: %s%n",
                                    currentItem,
                                    e.getMessage()
                            );
                        }
                    } finally {
                        //实时返回完成进度(不代表成功,只表示完成)
                        int currentProgress = completedCount.incrementAndGet();
                        if (groupedSequentialExecutorListener != null) {
                            groupedSequentialExecutorListener.handleProgress(groupIdAlias, groupId, currentItem, Thread.currentThread().getName(), totalTasks, currentProgress);
                        }
                    }
                }, executorService);
            }

            groupFutures.add(chain);
        }

        // 等待所有组完成
        CompletableFuture.allOf(groupFutures.toArray(new CompletableFuture[0])).join();
    }

    /**
     * 这个始终按输入数据的分组数设置线程的,设置了再多也不行
     * @param grouped
     */
    private void handleTwo(Map<Object, List<T>> grouped) {
        // 为每个组创建一个单线程的任务队列(模拟串行)
        List<CompletableFuture<Void>> groupFutures = new ArrayList<>();

        final ConcurrentMap<Object, String> groupIdToAlias = new ConcurrentHashMap<>();
        final AtomicInteger aliasCounter = new AtomicInteger(1); // 从 1 开始
        for (Map.Entry<Object, List<T>> entry : grouped.entrySet()) {
            Object groupId = entry.getKey();
            List<T> groupItems = entry.getValue();
            String groupIdAlias = getGroupIdAlias(groupId, groupIdToAlias, aliasCounter);

            // 使用一个单线程的“执行器”来保证该组任务的顺序
            // TODO 还有事务提交的问题 ,CompletableFuture.runAsync 还要优化,如果有事务,还要等同一个groupId的上个事务提交了
            // AI回复 没有问题
            // 确认下,这里提交事务是不是同步的,如果是,理论上也不问题
            CompletableFuture<Void> groupFuture = CompletableFuture.runAsync(() -> {
                for (T item : groupItems) {
                    try {
                        groupedSequentialExecutorListener.executeItem(item);
                    } catch (Exception e) {
                        if (groupedSequentialExecutorListener != null) {
                            groupedSequentialExecutorListener.handleError(groupIdAlias, groupId, item, e);
                        } else {
                            System.err.printf("任务处理出错: %s, 错误: %s%n", item, e.getMessage());
                        }
                    } finally {
                        int currentProgress = completedCount.incrementAndGet();
                        if (groupedSequentialExecutorListener != null) {
                            groupedSequentialExecutorListener.handleProgress(groupIdAlias, groupId, item, Thread.currentThread().getName(), totalTasks, currentProgress);
                        }
                    }
                }
            }, executorService);

            groupFutures.add(groupFuture);
        }

        // 等待所有组完成
        CompletableFuture.allOf(groupFutures.toArray(new CompletableFuture[0])).join();

        //// 清理状态 统一到上层
        //if (shouldShutdownExecutor && executorService != null) {
        //    executorService.shutdown();
        //}
        //totalTasks = 0;
        //completedCount.set(0);
    }


    private String getGroupIdAlias(Object groupId, ConcurrentMap<Object, String> groupIdToAlias, AtomicInteger aliasCounter) {
        return groupIdToAlias.computeIfAbsent(groupId, k -> "groupAlias-" + aliasCounter.getAndIncrement());
    }

private void check(){
        Preconditions.checkArgument(this.groupedSequentialExecutorListener !=null,"未实现GroupedSequentialExecutorListener接口");
        Preconditions.checkArgument(this.groupedSequentialExecutorListener.getTaskGroupKeyExtractor() != null,"未实现getTaskGroupKeyExtractor方法");
    }
}
正文到此结束
本文目录