实现Cglib动态代理

实现Cglib动态代理

Cglib实现的动态代理比较复杂,我们可以实现一个简单版的Cglib,只实现通过继承被代理类并覆写方法生成动态代理类。

定义方法拦截器

首先定义方法拦截器接口MyMethodInterceptor,如代码清单6-18所示。

代码清单6-18 MyMethodInterceptor方法拦截器

public interface MyMethodInterceptor {

Object intercept(Object obj, Method method, String methodName, Object[] objects)
 throws Throwable;

}

MyMethodInterceptor的intercept方法与cglib的MethodInterceptor的intercept方法有很大的差别。MyMethodInterceptor的intercept方法参数说明:

◆obj:代理类对象,即代理类中调用该方法时传递的this引用。

◆method:代理类提供给拦截器调用其父类方法的Method对象,在intercept方法中可使用该Method调用父类的方法。注意:methodProxy.getName获取到的方法是代理方法的名称。

◆methodName:被代理的方法名称。

◆objects:参数数组,将调用方法传递的参数包装成的一个数组。

实现MyEnhancer

与Cglib一样,我们也通过增强器来创建代理类,因此需要创建一个增强器类MyEnhancer。MyEnhancer的实现如代码清单6-19所示。

代码清单6-19 MyEnhancer类

public class MyEnhancer {

    private Class<?> superclass;
    private MyMethodInterceptor interceptor;

    public void setSuperclass(Class<?> superclass) {
        if (superclass.isInterface()) {
            throw new RuntimeException("父类不能是接口!");
        }
        if (superclass == Object.class) {
            throw new RuntimeException("不能代理Object类!");
        }
if ((superclass.getModifiers() & Modifier.FINAL) == Modifier.FINAL) {
            throw new RuntimeException("final类不允许继承!");
        }
        this.superclass = superclass;
    }

    public void setCallback(MyMethodInterceptor interceptor) {
        this.interceptor = interceptor;
    }

    public Object create() {
        if (superclass == null) {
            throw new RuntimeException("未设置父类!");
        }
        try {
            // 如果拦截器为空,无需创建代理类
            if (interceptor == null) {
                return superclass.newInstance();
            }
            // 代理类的名称
            String subclassName = Type.getInternalName(superclass) + "$Proxy";
            // 创建代理类
            byte[] subclassByteCode = SubclassProxyFactory
.createProxyClass(subclassName, superclass);
            // 使用自定义类加载器加载
Class<?> subclass = ByteCodeUtils.loadClass(subclassName, subclassByteCode);
            // 获取构造器,反射创建对象
Constructor<?> constructor = subclass.getConstructor(MyMethodInterceptor.class);
            return constructor.newInstance(interceptor);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

}

MyEnhancer提供三个方法,一个是setSuperclass方法,用于设置被代理的类(父类),需判断父类是否是接口,以及根据父类的访问标志判断父类是否允许被继承,如果不满足条件则抛出异常。第二个是setCallback方法,用于设置方法拦截器,第三个是create方法,用于生成代理类并通过反射创建一个代理对象。

create方法中,先判断是否设置了方法拦截器,如果没有设置方法拦截器,也就没有创建代理类的必要。接着调用SubclassProxyFactory的createProxyClass方法生成代理类,获取动态生成的代理类的字节数组,拿到代理类的字节数组之后使用自定义的类加载器加载,最后通过反射创建代理类对象。代理类的名称我们使用父类的名称加上“$Proxy”。

实现SubclassProxyFactory

SubclassProxyFactory的createProxyClass方法如代码清单6-20所示。

代码清单6-20 SubclassProxyFactory的createProxyClass方法

public static byte[] createProxyClass(String className, Class<?> superclass) {
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS
| ClassWriter.COMPUTE_FRAMES);
        cw.visit(Opcodes.V1_8, ACC_PUBLIC, className, null,
                Type.getInternalName(superclass),
                null);
        // 创建构造方法
        createInitMethod(cw, className, superclass);
        // 获取需要拦截的方法
        Method[] methods = getProxyMethod(superclass);
        if (methods.length > 0) {
            // 添加静态代码块,生成通过反射获取Method的字节码
            addStaticBlock(cw, className, superclass, methods);
            // 覆写父类的方法
            overrideMethods(cw, className, superclass, methods);
        }
        cw.visitEnd();
        return cw.toByteArray();
    }

在createProxyClass方法中,我们使用ClassWriter生成代理类。首先调用ClassWriter实例的visit方法设置我们要生成的代理类的类名,以及代理类继承的父类。接着调用createInitMethod方法为代理类生成类的实例初始化方法,也就是构造方法。最后调用getProxyMethod方法获取父类所有需要被代理的方法,调用overrideMethods方法为代理类重写父类方法。在重写父类方法之前,调用addStaticBlock方法用于为代理类添加静态字段并生成静态代码块,为静态字段赋值。

createInitMethod方法的实现如代码清单6-21所示。

代码清单6-21 createInitMethod方法

    private static void createInitMethod(ClassWriter cw, String className, Class<?> superclass) {
        MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "<init>",
                "(" + Type.getDescriptor(MyMethodInterceptor.class) + ")V",
                null, null);
        mv.visitCode();

        // 调用父类的无参实例初始化方法
        mv.visitVarInsn(ALOAD, 0);
        mv.visitMethodInsn(INVOKESPECIAL,
                Type.getInternalName(superclass),
                "<init>", "()V", false);

        // 添加一个字段,字段名:h,字段类型为:MyMethodInterceptor
        cw.visitField(ACC_PRIVATE, "h", Type.getDescriptor(MyMethodInterceptor.class),
                null, null);

        // 为字段赋值
        mv.visitVarInsn(ALOAD, 0);
        mv.visitVarInsn(ALOAD, 1);
        mv.visitFieldInsn(PUTFIELD, className, "h",
                Type.getDescriptor(MyMethodInterceptor.class));

        mv.visitInsn(RETURN);
        mv.visitMaxs(2, 2);
        mv.visitEnd();
    }

我们为代理类生成带一个参数且参数类型为MyMethodInterceptor的实例初始化方法,也就是要求创建代理类时传入方法拦截器。在实例初始化方法中,先调用父类的无参实例初始化方法,再为子类添加一个类型为MyMethodInterceptor且名称为h的字段,将参数赋值给h字段。

getProxyMethod方法的实现如代码清单6-22所示。

代码清单6-22 getProxyMethod方法

private static Method[] getProxyMethod(Class<?> superclass) {
        // 获取superclass类的所有public方法,包括superclass的父类中的public方法
        Method[] methods = superclass.getMethods();
        List<Method> methodList = new ArrayList<>(methods.length);
        for (Method method : methods) {
            // 过滤不需要重写的方法
            if (EXCLUDE_METHOD.contains(method.getName())) {
                continue;
            }
            // 跳过final修饰的方法
            if ((method.getModifiers() & Modifier.FINAL) == Modifier.FINAL) {
                continue;
            }
            methodList.add(method);
        }
        return methodList.toArray(new Method[]{});
    }

如代码清单6-22所示,我们通过反射获取父类所有public方法,过滤不需要代理的方法以及不允许被子类重写的方法。不需要代理的方法如wait、equals等方法,而不允许子类覆写的方法可根据方法的访问标志判断方法是否允许被子类重写。我们不代理以下方法:

private static final List<String> EXCLUDE_METHOD = Arrays.asList("wait", "equals",
            "toString", "hashCode", "getClass", "notify", "notifyAll");

addStaticBlock方法的实现如代码清单6-23所示。

代码清单6-23 addStaticBlock方法

private static void addStaticBlock(ClassWriter cw, String className,
                                       Class<?> superclass, Method[] methods) {
        // 给<clinit>方法添加static访问标志。对应java代码的静态代码块
        MethodVisitor mv = cw.visitMethod(ACC_STATIC, "<clinit>", "()V",
                null, null);
        mv.visitCode();
        for (int i = 0; i < methods.length; i++) {
            Method method = methods[i];

            // 字段名取方法名称+i,避免重载方法的字段名相同的情况
            String fieldName = "_" + method.getName() + "_" + i;
            // 添加静态字段
            cw.visitField(ACC_PRIVATE | ACC_STATIC, fieldName, Type.getDescriptor(Method.class), null, null);

            // 生成一个调用父类方法的代理方法
            addCallSuperclassMethod(cw, superclass, method.getName() + "_" + i, method);

            // 调用Class的forName方法获取this的Class实例
            mv.visitLdcInsn(className.replace("/", "."));
            mv.visitMethodInsn(INVOKESTATIC, Type.getInternalName(Class.class),
                    "forName",
                    "(Ljava/lang/String;)Ljava/lang/Class;",
                    false);

            // 调用Class的getMethod方法,方法需要两个参数,一个是方法名称,一个是方法参数类型数组
            // 参数1
            mv.visitLdcInsn(method.getName() + "_" + i);
            // 参数2
            // 创建数组,并为数组的每个元素赋值
            Class[] methodParamTypes = method.getParameterTypes();
            if (methodParamTypes.length == 0) {
                mv.visitInsn(ACONST_NULL);
            } else {
                // 数组大小
                switch (methodParamTypes.length) {
                    case 1:
                        mv.visitInsn(ICONST_1);
                        break;
                    case 2:
                        mv.visitInsn(ICONST_2);
                        break;
                    case 3:
                        mv.visitInsn(ICONST_3);
                        break;
                    default:
                        mv.visitVarInsn(BIPUSH, methodParamTypes.length);
                }
                mv.visitTypeInsn(ANEWARRAY, Type.getInternalName(Class.class));
                // 为数组元素赋值,数组元素类型为java.lang.Class
                for (int index = 0; index < methodParamTypes.length; index++) {
                    mv.visitInsn(DUP);
                    // 数组元素下标
                    switch (index) {
                        case 0:
                            mv.visitInsn(ICONST_0);
                            break;
                        case 1:
                            mv.visitInsn(ICONST_1);
                            break;
                        case 2:
                            mv.visitInsn(ICONST_2);
                            break;
                        case 3:
                            mv.visitInsn(ICONST_3);
                            break;
                        default:
                            mv.visitVarInsn(BIPUSH, i);
                            break;
                    }
                    mv.visitLdcInsn(methodParamTypes[index].getName());
                    // 调用forName获取参数的Class实例
                    mv.visitMethodInsn(INVOKESTATIC, Type.getInternalName(Class.class),
                            "forName",
                            "(Ljava/lang/String;)Ljava/lang/Class;",
                            false);
                    // 存储到数组
                    mv.visitInsn(AASTORE);
                }
            }
            // 参数准备完毕,调用Class的getMethod方法获取Method
            mv.visitMethodInsn(INVOKEVIRTUAL, Type.getInternalName(Class.class),
                    "getMethod",
                    "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;", false);
            // 为静态字段赋值
            mv.visitFieldInsn(PUTSTATIC, className, fieldName, Type.getDescriptor(Method.class));
        }
        mv.visitInsn(RETURN);
        mv.visitMaxs(1, 1);
        mv.visitEnd();
    }

addStaticBlock方法中,首先为动态代理类生成类的实例初始化方法<clinit>,将方法的访问标志声明为static,对应java代码中的静态代码块。与实现JDK动态代理有所不同,此静态代码块中通过反射获取的Method对象是代理类提供给方法拦截器调用父类方法的额外方法,该方法是调用addCallSuperclassMethod方法生成的,addCallSuperclassMethod方法如代码清单6-24所示。

代码清单6-24 addCallSuperclassMethod方法

private static void addCallSuperclassMethod(ClassWriter cw, Class<?> superclass,
String methodName, Method method) {
        MethodVisitor mv = cw.visitMethod(ACC_PUBLIC | ACC_FINAL, methodName,
                Type.getMethodDescriptor(method),
                null, null);
        mv.visitCode();
        // this
        mv.visitVarInsn(ALOAD, 0);
        // 参数入栈
        Class<?>[] paramTypes = method.getParameterTypes();
        if (paramTypes.length > 0) {
            for (int i = 0; i < paramTypes.length; i++) {
                Class<?> paramType = paramTypes[i];
                if (paramType == int.class) {
                    mv.visitVarInsn(ILOAD, i + 1);
                } else if (paramType == long.class) {
                    mv.visitVarInsn(LLOAD, i + 1);
                }
                //....
                else {
                    mv.visitVarInsn(ALOAD, i + 1);
                }
            }
        }
        // 调用父类的方法
        mv.visitMethodInsn(INVOKESPECIAL, Type.getInternalName(superclass),
                method.getName(), Type.getMethodDescriptor(method), false);
        // 生成return指令
        addReturnInstruc(mv, method.getReturnType());

        mv.visitMaxs(1, 1);
        mv.visitEnd();
    }

addCallSuperclassMethod方法为代理类生成一个提供给方法拦截器调用代理类父类方法的额外方法,为该方法添加一条调用父类方法的字节码指令。在添加调用父类方法的字节码指令之前,我们需要先添加将调用方法所需的参数入栈的字节码指令。通过反射获取方法的参数类型,根据参数类型以及参数的索引,生成对应的LOAD指令。调用父类方法使用invokespecial指令,且第二个参数传父类的内部类名。

overrideMethods方法的实现如代码清单6-25所示。

代码清单6-25 overrideMethods方法

private static void overrideMethods(ClassWriter cw, String className,
                                        Class<?> superclass, Method[] methods) {
        for (int i = 0; i < methods.length; i++) {
            Method method = methods[i];
            // 覆写父类的方法
            MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, method.getName(),
                    Type.getMethodDescriptor(method),
                    null,
                    new String[]{Type.getInternalName(Exception.class)});
            mv.visitCode();

            Label from = new Label();
            Label to = new Label();
            Label target = new Label();

            // try开始
            mv.visitLabel(from);

            // 获取字段,字段名为h,类型为MyMethodInterceptor
            mv.visitVarInsn(ALOAD, 0);
            mv.visitFieldInsn(GETFIELD,
                    className,
                    "h",
                    Type.getDescriptor(MyMethodInterceptor.class));

            // 准备调用MyMethodInterceptor的intercept方法的三个参数
            // 第一个参数
            mv.visitVarInsn(ALOAD, 0);
            mv.visitTypeInsn(CHECKCAST, Type.getInternalName(superclass));
            // 第二个参数。获取静态字段
            mv.visitFieldInsn(GETSTATIC,
                    className,
                    "_" + method.getName() + "_" + i,
                    Type.getDescriptor(Method.class));

            // 第三个参数
            mv.visitLdcInsn(method.getName());

            // 第四个参数,将当前方法的参数构造成数组
            int paramCount = method.getParameterCount();
            if (paramCount == 0) {
                mv.visitInsn(ACONST_NULL);
            } else {
                // 数组大小
                switch (paramCount) {
                    case 1:
                        mv.visitInsn(ICONST_1);
                        break;
                    case 2:
                        mv.visitInsn(ICONST_2);
                        break;
                    case 3:
                        mv.visitInsn(ICONST_3);
                        break;
                    default:
                        mv.visitVarInsn(BIPUSH, paramCount);
                }
                // 创建数组
                mv.visitTypeInsn(ANEWARRAY, Type.getInternalName(Object.class));
                // 为数组元素赋值
                for (int index = 1; index <= paramCount; index++) {
                    mv.visitInsn(DUP);
                    switch (index - 1) {
                        case 0:
                            mv.visitInsn(ICONST_0);
                            break;
                        case 1:
                            mv.visitInsn(ICONST_1);
                            break;
                        case 2:
                            mv.visitInsn(ICONST_2);
                            break;
                        case 3:
                            mv.visitInsn(ICONST_3);
                            break;
                        default:
                            mv.visitVarInsn(BIPUSH, i - 1);
                            break;
                    }
                    // 暂不考虑参数类型为基本数据类型的情况
                    mv.visitVarInsn(ALOAD, index);
                    mv.visitInsn(AASTORE);
                }
            }

            // 调用MyMethodInterceptor的intercept方法
            mv.visitMethodInsn(INVOKEINTERFACE,
                    Type.getInternalName(MyMethodInterceptor.class),
                    "intercept",
                    "(Ljava/lang/Object;Ljava/lang/reflect/Method;Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/Object;",
                    true);

            // 添加return指令
            addReturnInstruc(mv, method.getReturnType());

            // try结束
            mv.visitLabel(to);
            // catch开始
            mv.visitLabel(target);
            // 抛出异常
            mv.visitInsn(ATHROW);
            // 添加TryCatch代码块
            mv.visitTryCatchBlock(from, to, target, Type.getInternalName(Exception.class));

            mv.visitFrame(F_FULL, 0, null, 0, null);
            mv.visitMaxs(1, 1);
            mv.visitEnd();
        }
    }

overrideMethods方法实现重写父类方法拦截方法的调用。为重写的方法添加调用方法拦截器intercept方法的字节码指令,并为方法调用添加try-catch。

使用MyEnhancer创建代理对象

使用我们实现的增强器MyEnhancer创建HttpRequestTemplateImpl代理类,如代码清单6-26所示。

代码清单6-26 使用MyEnhancer创建HttpRequestTemplateImpl代理类

     public static void main(String[] args) {
        MyEnhancer enhancer = new MyEnhancer();
        enhancer.setSuperclass(HttpRequestTemplateImpl.class);
        enhancer.setCallback(new HttpRequestMyMethodInterceptor());
        HttpRequestTemplateImpl proxyObj = (HttpRequestTemplateImpl) enhancer.create();
        HttpRequest request = new HttpRequest("http://127.0.0.1:8080/book/list", "GET");
        proxyObj.doGet(request);
    }

HttpRequestMyMethodInterceptor方法拦截器如代码清单6-27所示。

代码清单6-27 HttpRequestMyMethodInterceptor

public class HttpRequestMyMethodInterceptor implements MyMethodInterceptor {

    @Override
    public Object intercept(Object obj, Method methodProxy, String methodName, Object[] objects) throws Throwable {
        long startMs = System.currentTimeMillis();
        try {
            // 使用代理类的代理方法调用父类的方法
            return methodProxy.invoke(obj, objects);
        } finally {
            long cntMs = System.currentTimeMillis() - startMs;
            System.out.println(methodName + "方法的执行耗时为" + cntMs + "毫秒");
        }
    }

}

MyEnhancer生成的HttpRequestTemplateImpl代理类如代码清单6-28所示。

代码清单6-28 HttpRequestTemplateImpl代理类

public class HttpRequestTemplateImpl$Proxy extends HttpRequestTemplateImpl {
    private MyMethodInterceptor h;
private static Method _doGet_0 =
Class.forName("com.wujiuye.asmbytecode.book.sixth.HttpRequestTemplateImpl$Proxy")
.getMethod("doGet_0",
Class.forName("com.wujiuye.asmbytecode.book.sixth.HttpRequest"));
private static Method _doPost_1 =
Class.forName("com.wujiuye.asmbytecode.book.sixth.HttpRequestTemplateImpl$Proxy")
.getMethod("doPost_1", Class.forName("com.wujiuye.asmbytecode.book.sixth.HttpRequest"));

    public HttpRequestTemplateImpl$Proxy(MyMethodInterceptor var1) {
        this.h = var1;
    }

    public final HttpResponse doGet_0(HttpRequest var1) {
        return (HttpResponse)super.doGet(var1);
    }

    public final HttpResponse doPost_1(HttpRequest var1) {
        return (HttpResponse)super.doPost(var1);
    }

    public HttpResponse doGet(HttpRequest var1) throws Exception {
        try {
            return (HttpResponse)this.h.intercept((HttpRequestTemplateImpl)this,
_doGet_0, "doGet", new Object[]{var1});
        } catch (Exception var2) {
            throw var2;
        }
    }

    public HttpResponse doPost(HttpRequest var1) throws Exception {
        try {
            return (HttpResponse)this.h.intercept((HttpRequestTemplateImpl)this,
_doPost_1, "doPost", new Object[]{var1});
        } catch (Exception var2) {
            throw var2;
        }
    }
}

其中doGet_0方法就是我们为代理类额外添加的方法,_doGet_0静态字段持有这个方法的Method对象的引用,在重写父类的doGet方法时,调用方法拦截器的intercept方法将this引用、_doGet_0静态字段,以及调用方法的参数传递给方法拦截器,所以我们可以在方法拦截器中通过method参数调用父类的方法。

从实现JDK动态代理与实现Cglib动态代理的过程中,我们发现,这两种实现动态代理的方式在实现上有很多相同的地方,从字节码层面来看,两个代理方法在性能上是没有多大区别的。如果代理类实现接口,我们可使用代理接口的方式实现动态代理,如果接口的实现类有多个,这种方式我们只需要生成一个代理类。