原创

MyBatis超时设置和异常重试拦截器

package cn.jiangjiesheng.config.mybatis;

import cn.jiangjiesheng.common.GnStatic;
import cn.jiangjiesheng.core.utils.StringUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.mybatis.spring.MyBatisSystemException;
import org.postgresql.util.PSQLException;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.jdbc.datasource.ConnectionHolder;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.apache.commons.lang3.exception.ExceptionUtils;

import java.io.EOFException;
import java.lang.reflect.InvocationTargetException;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;


/**
 * 数据库查询超时判断 和  异常重试 拦截器
 *
 * -- 模拟更新出现30%概率的异常
 * CREATE OR REPLACE FUNCTION random_failure()
 * RETURNS trigger AS $$
 * BEGIN
 *     -- 以一定概率(例如 30%)抛出异常
 *     IF random() < 0.3 THEN
 *         RAISE EXCEPTION 'Random failure occurred';
 *     END IF;
 *     RETURN NEW;
 * END;
 * $$ LANGUAGE plpgsql;
 *
 * -- 创建触发器,在插入数据时可能抛出异常
 * CREATE TRIGGER maybe_fail_trigger
 * BEFORE UPDATE ON system_area_info
 * FOR EACH ROW EXECUTE PROCEDURE random_failure();
 *
 * SELECT * from system_area_info
 *
 * -- 删除触发器
 * DROP TRIGGER IF EXISTS maybe_fail_trigger ON system_area_info;
 * 
 * SQL测试:
 * UPDATE system_area_info SET update_time = now() WHERE id = 330784;
 * 代码测试:
 * SystemAreaInfoDO update = new SystemAreaInfoDO();
 * update.setId(330784L);
 * update.setUpdateTime(new Date());
 * systemAreaMapper.updateById(update);
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "update", args = {Statement.class}),
        @Signature(type = StatementHandler.class, method = "batch", args = {Statement.class}),
        @Signature(type = StatementHandler.class, method = "query", args = {Statement.class, ResultHandler.class})})
@Slf4j
public class MybatisTransactionTimeoutInterceptor implements Interceptor {

    /**
     * 最大重试次数
     * pgsqlMaxRetryTimes外部入参 或者 PGSQL_MAX_RETRY_TIMES默认值 设置为0时不重试
     */
    private Integer PGSQL_MAX_RETRY_TIMES = 2;
    /**
     * 出现异常后初始延迟毫秒
     */
    private Integer PGSQL_INITIAL_DELAY_MS = 50;
    /**
     * 设置最大深度
     */
    private static final int MAX_CAUSE_DEPTH = 5;

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        return executeWithRetry(invocation, 0);
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    @Override
    public void setProperties(Properties properties) {
        // 可以设置一些参数,例如最大重试次数等
        String maxRetriesStr = properties.getProperty("pgsqlMaxRetryTimes");
        if (StringUtils.isNotBlank(maxRetriesStr)) {
            PGSQL_MAX_RETRY_TIMES = Integer.parseInt(maxRetriesStr);
        }
    }

    /**
     * 带重试的执行
     * PGSQL_MAX_RETRY_TIMES=0则不重试
     *
     * @param invocation
     * @param retryCount
     * @return
     * @throws Throwable
     */
    private Object executeWithRetry(Invocation invocation, int retryCount) throws Throwable {
        Statement stmt = (Statement) invocation.getArgs()[0];
        adjustQueryTimeout(stmt, retryCount);
        try {
            Object proceed = invocation.proceed();
            if (retryCount > 0) {
                //记录到error级别,方便查看分析
                log.error("sql执行出现异常,重试执行成功,traceId:{},当前为第{}次重试的结果", GnStatic.getTraceId(), retryCount);
            }
            return proceed;
        } catch (Throwable e) {
            if (PGSQL_MAX_RETRY_TIMES <= 0) {
                throw e;
            }
            boolean retryableException;
            try {
                retryableException = isRetryableException(e);
            } catch (Exception ignore) {
                //出现未知的错误的,都不执行重试
                throw e;
            }
            if (retryableException && retryCount < PGSQL_MAX_RETRY_TIMES) {
                retryCount++;
                long delay = calculateDelayMillisecond(retryCount);
                if (retryCount == 1) {
                    StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
                    BoundSql boundSql = statementHandler.getBoundSql();
                    String sql = boundSql.getSql().replaceAll("[\\s]+", " ").trim();;
                    log.error("sql执行出现异常,正在尝试重试,traceId:{},当前将进行第{}次重试,sql:{},异常:", GnStatic.getTraceId(), retryCount, sql, e);
                } else {
                    log.error("sql执行出现异常,正在尝试重试,traceId:{},当前将进行第{}次重试", GnStatic.getTraceId(), retryCount);
                }
                // 使用线程休眠来实现延迟,也可以考虑使用非阻塞方法
                Thread.sleep(delay);

                // 递归调用自身进行重试
                return executeWithRetry(invocation, retryCount);
            } else {
                String message;
                try {
                    message = getMessageFromExceptionV1(e);
                } catch (Exception ignore) {
                    //出现未知的错误的,都不执行重试
                    throw e;
                }
                if (retryCount > 0) {
                    log.error("sql执行出现异常,结束重试,traceId:{},当前已重试次数:{},异常:{}", GnStatic.getTraceId(), retryCount, message);
                } else {
                    log.error("sql执行出现异常,不支持重试的异常,traceId:{},异常:{}", GnStatic.getTraceId(), message);
                }
                throw e;
            }
        }
    }

    /**
     * 动态设置超时时间
     *
     * @param stmt
     * @param retryCount
     * @throws SQLException
     */
    private void adjustQueryTimeout(Statement stmt, int retryCount) throws SQLException {
        Collection<Object> values = TransactionSynchronizationManager.getResourceMap().values();
        if (!values.isEmpty()) {
            for (Object obj : values) {
                if (obj instanceof ConnectionHolder) {
                    ConnectionHolder holder = (ConnectionHolder) obj;
                    if (holder.hasTimeout()) {
                        int baseQueryTimeOut = holder.getTimeToLiveInSeconds();
                        int queryTimeOut = Math.min(baseQueryTimeOut, calculateTimeoutSecond(retryCount));
                        stmt.setQueryTimeout(queryTimeOut);
                    }
                    break;
                }
            }
        }
    }

    /**
     * 动态计算超时时间
     *
     * @param retryCount
     * @return
     */
    private int calculateTimeoutSecond(int retryCount) {
        // 指数退避算法,可以根据重试次数调整查询超时时间
        // 转换为秒
        return (int) (calculateDelayMillisecond(retryCount) / 1000);
    }

    /**
     * 计算延迟查询的事件
     *
     * @param attempt
     * @return
     */
    private long calculateDelayMillisecond(int attempt) {
        // 指数退避算法,可以根据SQL命令类型调整初始延迟时间和增长因子
        // 初始延迟为50ms,每次翻倍
        return (long) Math.pow(2, attempt - 1) * PGSQL_INITIAL_DELAY_MS;
    }

    /**
     * 当前是否是需要重试的异常
     * @param e
     * @return
     */
    private boolean isRetryableException(Throwable e) {
        if (e == null) {
            return false;
        }
        // 如果异常是InvocationTargetException,则获取其目标异常进行处理
        if (e instanceof InvocationTargetException) {
            e = ((InvocationTargetException) e).getTargetException();
        }
        // 获取异常链中最底层的原因异常
        Throwable rootCause = ExceptionUtils.getRootCause(e);
        if (rootCause == null) {
            rootCause = e;
        }
        // 检查最底层的原因异常是否为特定的可重试异常类型之一
        if (PostgreSqlExceptionInfo.checkIsRetryableException(rootCause)) {
            return true;
        }

        // 检查当前异常或其原因是否为特定的可重试异常类型之一
        // 防止无限循环的同时处理递归异常和循环引用(自身)
        // 设置最大遍历深度
        int depth = 0;
        Set<Throwable> seen = new HashSet<>();
        while (e != null && depth < MAX_CAUSE_DEPTH && !seen.contains(e)) {
            if (PostgreSqlExceptionInfo.checkIsRetryableException(e)) {
                return true;
            }
            seen.add(e);
            depth++;
            // 继续检查异常的原因
            e = e.getCause();
        }
        return false;
    }

    /**
     * 解析异常msg 方法1
     *
     * @param e
     * @return
     */
    private static String getMessageFromExceptionV1(Throwable e) {
        if (e == null) {
            return "Unknown exception";
        }

        StringBuilder messageBuilder = new StringBuilder();
        // 尝试获取异常链中最深层的非空消息
        Throwable deepestCause = ExceptionUtils.getRootCause(e);
        if (deepestCause == null) {
            deepestCause = e;
        }
        // 添加异常类名和消息
        messageBuilder.append(deepestCause.getClass().getSimpleName())
                .append(": ")
                .append(StringUtils.defaultIfBlank(deepestCause.getMessage(), deepestCause.toString()));

        // 使用集合记录已访问的异常,防止循环引用
        Set<Throwable> seen = new HashSet<>();
        seen.add(deepestCause);

        // 如果有其他原因,递归添加
        Throwable cause = deepestCause.getCause();
        int depth = 0;
        while (cause != null && depth < MAX_CAUSE_DEPTH && !seen.contains(cause)) {
            messageBuilder.append(" Caused by: ")
                    .append(cause.getClass().getSimpleName())
                    .append(": ")
                    .append(StringUtils.defaultIfBlank(cause.getMessage(), cause.toString()));
            seen.add(cause);
            cause = cause.getCause();
            depth++;
        }

        return messageBuilder.toString();
    }

    /**
     * 解析异常msg 方法2
     *
     * @param e
     * @return
     */
    private static String getMessageFromExceptionV2(Throwable e) {
        if (e == null) {
            return "Unknown exception";
        }
        StringBuilder sb = new StringBuilder();
        sb.append(StringUtils.defaultString(e.getMessage()));
        if (e.getCause() != null) {
            sb.append(StringUtils.defaultString(e.getCause().getMessage()));
        }
        //e.toString()的结果
        sb.append(e);
        return sb.toString();
    }

    private static class PostgreSqlExceptionInfo {
        /**
         * 异常msg class对象
         */
        private static final Set<Class<? extends Throwable>> RETRYABLE_EXCEPTIONS = Collections.unmodifiableSet(
                new HashSet<>(Arrays.asList(
                        PSQLException.class,
                        EOFException.class,
                        DataAccessResourceFailureException.class,
                        MyBatisSystemException.class
                ))
        );

        /**
         * 异常msg关键词
         */
        private static final List<String> RETRYABLE_EXCEPTION_MESSAGES = Collections.unmodifiableList(
                Stream.concat(
                        RETRYABLE_EXCEPTIONS.stream().map(Class::getSimpleName),
                        //追加其他关键词
                        Stream.of("I/O error","java.io.EOFException")
                ).collect(Collectors.toList())
        );


        /**
         * 检查当前是否可重试的异常
         * @param e
         * @return
         */
        public static boolean checkIsRetryableException(Throwable e) {
            if (e == null) {
                return false;
            }
            // 检查异常类型
            if (RETRYABLE_EXCEPTIONS.contains(e.getClass())) {
                return true;
            }
            // 检查异常信息中是否包含特定字符串
            String message ;

            try {
                message = getMessageFromExceptionV1(e) + getMessageFromExceptionV2(e);
            } catch (Exception ignore) {
                //出现未知的错误的,都不执行重试
                return false;
            }
            if (StringUtils.isBlank(message)) {
                return false;
            }
            return RETRYABLE_EXCEPTION_MESSAGES.stream().map(String::toLowerCase)
                    .anyMatch(message.toLowerCase()::contains);
        }

        private PostgreSqlExceptionInfo() {
        }
    }

}
正文到此结束
本文目录