001    package ca.discotek.feenix.asm;
002    
003    import java.io.File;
004    import java.io.FileInputStream;
005    import java.io.FileOutputStream;
006    import java.io.IOException;
007    
008    import ca.discotek.feenix.ClassManager;
009    import ca.discotek.feenix.Util;
010    import ca.discotek.rebundled.org.objectweb.asm.ClassReader;
011    import ca.discotek.rebundled.org.objectweb.asm.ClassVisitor;
012    import ca.discotek.rebundled.org.objectweb.asm.ClassWriter;
013    import ca.discotek.rebundled.org.objectweb.asm.FieldVisitor;
014    import ca.discotek.rebundled.org.objectweb.asm.Label;
015    import ca.discotek.rebundled.org.objectweb.asm.MethodVisitor;
016    import ca.discotek.rebundled.org.objectweb.asm.Opcodes;
017    import ca.discotek.rebundled.org.objectweb.asm.Type;
018    import ca.discotek.rebundled.org.objectweb.asm.commons.LocalVariablesSorter;
019    import ca.discotek.rebundled.org.objectweb.asm.util.CheckClassAdapter;
020    
021    public class ModifyClassVisitor extends ClassVisitor {
022    
023        public static final String CHECK_UPDATE_METHOD_NAME = "__check_update__";
024        public static final String INTER_VAR_NAME = "__inter__";
025        
026        String className;
027        String interfaceName;
028        String interfaceNameDesc;
029        
030        public ModifyClassVisitor(ClassVisitor cv) {
031            super(Opcodes.ASM5, cv);
032        }
033        
034        public static byte[] generate(String className, File file) {
035            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
036            ModifyClassVisitor cv = new ModifyClassVisitor(cw);
037            
038            FileInputStream fis = null;
039            try {
040                fis = new FileInputStream(file);
041                ClassReader cr = new ClassReader(fis);
042                cr.accept(cv, ClassReader.SKIP_FRAMES);
043                
044                return cw.toByteArray();
045            } 
046            catch (Exception e) {
047                e.printStackTrace();
048                return null;
049            }
050            finally {
051                if (fis != null) {
052                    try { fis.close(); } 
053                    catch (IOException e) {
054                        // no other option - log
055                        e.printStackTrace();
056                    }
057                }
058            }
059        }
060        
061        public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
062            super.visit(version, access, name, signature, superName, interfaces);
063            this.className = name;
064            this.interfaceName = ClassManager.getInterfaceName(className);
065            this.interfaceNameDesc = "L" + interfaceName + ";";
066            
067            FieldVisitor fv = super.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC, INTER_VAR_NAME, interfaceNameDesc, null, null);
068            fv.visitEnd();
069            
070            MethodVisitor mv = super.visitMethod(Opcodes.ACC_PROTECTED | Opcodes.ACC_STATIC, CHECK_UPDATE_METHOD_NAME, "()V", null, null);
071            LocalVariablesSorter lvs = new LocalVariablesSorter(Opcodes.ACC_PROTECTED, "()V", mv);
072            lvs.visitCode();
073            
074            lvs.visitLdcInsn(Type.getType("L" + className + ";"));
075            lvs.visitMethodInsn(Opcodes.INVOKESTATIC, Util.toSlashName(ClassManager.class), "getUpdate", "(Ljava/lang/Class;)Ljava/lang/Object;", false);
076            lvs.visitInsn(Opcodes.DUP);
077            
078            int index = lvs.newLocal(Type.getType(interfaceName));
079            lvs.visitVarInsn(Opcodes.ASTORE, index);
080            lvs.visitVarInsn(Opcodes.ALOAD, index);
081            Label l0 = new Label();
082            lvs.visitJumpInsn(Opcodes.IFNULL, l0);
083            lvs.visitVarInsn(Opcodes.ALOAD, index);
084            mv.visitFieldInsn(Opcodes.PUTSTATIC, className, INTER_VAR_NAME, interfaceNameDesc);
085            mv.visitLabel(l0);
086            
087            mv.visitInsn(Opcodes.RETURN);
088            mv.visitMaxs(2, 1);
089            
090            mv.visitEnd();
091        }
092    
093        public FieldVisitor visitField(int access, String name, String desc, String signature, Object value) {
094            int newAccess = Util.convertToPublicAccess(access);
095            newAccess = Util.removeAttribute(newAccess, Opcodes.ACC_FINAL);
096            return super.visitField(newAccess, name, desc, signature, value);
097        }
098        
099        
100        int index = 0;
101    
102        public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
103            MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
104            try {
105                return new ModifyMethodVisitor(index, name.equals("<init>"), mv, access, name, desc);            
106            }
107            finally {
108                index++;
109            }
110        }
111        
112        class ModifyMethodVisitor extends LocalVariablesSorter {
113    
114            final int invokeIndex;
115            final String name; 
116            final String desc;
117            final boolean isStatic;
118            final boolean isConstructor;
119            
120            public ModifyMethodVisitor(int invokeIndex, boolean isConstructor, MethodVisitor mv, int access, String name, String desc) {
121                super(Opcodes.ASM5, access, desc, mv);
122                this.invokeIndex = invokeIndex;
123                this.isConstructor = isConstructor;
124                this.name = name;
125                this.desc = desc;
126                isStatic = Util.isEnabled(access, Opcodes.ACC_STATIC);
127            }
128    
129            public void visitCode() {
130                super.visitCode();
131                mv.visitMethodInsn(Opcodes.INVOKESTATIC, className, CHECK_UPDATE_METHOD_NAME, "()V", false);
132                mv.visitFieldInsn(Opcodes.GETSTATIC, className, INTER_VAR_NAME, interfaceNameDesc);
133                
134                Label l0 = new Label();
135                mv.visitJumpInsn(Opcodes.IFNULL, l0);
136                mv.visitFieldInsn(Opcodes.GETSTATIC, className, INTER_VAR_NAME, interfaceNameDesc);
137                
138                
139                if (isStatic)
140                    mv.visitInsn(Opcodes.ACONST_NULL);
141                else
142                    mv.visitVarInsn(Opcodes.ALOAD, 0);
143                
144                Type methodType = Type.getMethodType(desc);
145                Type argTypes[] = methodType.getArgumentTypes();
146    
147                mv.visitFieldInsn(Opcodes.GETSTATIC, className, INTER_VAR_NAME, interfaceNameDesc);
148                mv.visitLdcInsn(invokeIndex);
149                if (isStatic)
150                    mv.visitInsn(Opcodes.ACONST_NULL);
151                else
152                    mv.visitVarInsn(Opcodes.ALOAD, 0);
153                mv.visitLdcInsn(argTypes.length);
154                mv.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object");
155                
156                for (int i=0; i<argTypes.length; i++) {
157                    mv.visitInsn(Opcodes.DUP);
158                    mv.visitLdcInsn(i);
159                    mv.visitVarInsn(Opcodes.ALOAD, isStatic ? i : i+1);  // no +1 if static
160                    mv.visitInsn(Opcodes.AASTORE);
161                }
162                
163                mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, interfaceName, InterfaceGenerator.INVOKE_METHOD_NAME, "(IL" + className + ";[Ljava/lang/Object;)Ljava/lang/Object;", true);
164                
165                Type returnType = methodType.getReturnType();
166                if (returnType.getSort() == Type.VOID) {
167                    mv.visitInsn(Opcodes.POP);
168                    mv.visitInsn(Opcodes.RETURN);
169                }
170                else {
171                    Util.insertCheckCast(returnType, mv);
172                    Util.insertReturnType(mv, returnType);
173                }
174                
175                mv.visitLabel(l0);
176            }
177        }
178    
179    }
180