Index: java/security/AccessController.java
===================================================================
RCS file: /cvsroot/classpath/classpath/java/security/AccessController.java,v
retrieving revision 1.5
diff -u -w -r1.5 AccessController.java
--- java/security/AccessController.java 6 May 2002 16:19:20 -0000 1.5
+++ java/security/AccessController.java 10 Feb 2004 18:52:38 -0000
@@ -1,5 +1,5 @@
 /* AccessController.java --- Access control context and permission checker
-   Copyright (C) 2001 Free Software Foundation, Inc.
+   Copyright (C) 2001,2004 Free Software Foundation, Inc.
 
 This file is part of GNU Classpath.
 
@@ -37,6 +37,15 @@
 
 package java.security;
 
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
 /**
  * Access control context and permission checker.
  * Can check permissions in the access control context of the current thread
@@ -58,6 +67,27 @@
  */
 public final class AccessController
 {
+
+  /**
+   * A mapping between pairs (thread, classname) to access
+   * control contexts. The thread and classname are the thread
+   * and classname current as of the last call to doPrivileged with
+   * an AccessControlContext argument.
+   */
+  private final static Map contexts = Collections.synchronizedMap(new HashMap());
+
+  private final static AccessControlContext defaultContext;
+  static
+  {
+    CodeSource source = new CodeSource(null, null);
+    Permissions permissions = new Permissions();
+    permissions.add(new AllPermission());
+    ProtectionDomain[] domain = new ProtectionDomain[] {
+      new ProtectionDomain(source, permissions)
+    };
+    defaultContext = new AccessControlContext(domain);
+  }
+
   /**
    * This class only has static methods so there is no public contructor.
    */
@@ -115,8 +145,19 @@
   public static Object doPrivileged(PrivilegedAction action,
                                     AccessControlContext context)
   {
+    final List pair = new ArrayList(2);
+    pair.add(Thread.currentThread());
+    pair.add(action.getClass().getName());
+    contexts.put(pair, context);
+    try
+      {
     return action.run();
   }
+    finally
+      {
+        contexts.remove(pair);
+      }
+  }
 
   /**
    * Calls the run() method of the given action with as
@@ -170,6 +211,10 @@
                                     AccessControlContext context)
     throws PrivilegedActionException
   {
+    final List pair = new ArrayList(2);
+    pair.add(Thread.currentThread());
+    pair.add(action.getClass().getName());
+    contexts.put(pair, context);
 
     try
       {
@@ -179,6 +224,10 @@
       {
         throw new PrivilegedActionException(e);
       }
+    finally
+      {
+        contexts.remove(pair);
+      }
   }
 
   /**
@@ -191,7 +240,64 @@
    */
   public static AccessControlContext getContext()
   {
-    // For now just return an new empty context
-    return new AccessControlContext(new ProtectionDomain[0]);
+    List domains = new LinkedList();
+    Set classes = new HashSet();
+    StackTraceElement[] stack = new Throwable().getStackTrace();
+
+    // Calls to Class.getProtectionDomain will recursively call this
+    // method, so we return a default context (one that is guaranteed
+    // to allow this class access) if we find that we are already in
+    // a call to getContext.
+    for (int i = 1; i < stack.length; i++)
+      {
+        if (stack[i].getClassName().equals(AccessController.class.getName())
+            && stack[i].getMethodName().equals("getContext"))
+          return defaultContext;
+      }
+
+    // We walk down the stack, adding each ProtectionDomain for each
+    // class in the call stack. If we reach a call to doPrivileged,
+    // we don't add any more stack frames. We skip the first stack frame
+    // since it is the call to getContext itself.
+    for (int i = 1; i < stack.length; i++)
+      {
+        // Test for doPrivileged before the duplicate-class check: if an
+        // earlier frame already recorded this class, the doPrivileged
+        // frame would be skipped as a duplicate and never end the walk.
+        if (stack[i].getClassName().equals(AccessController.class.getName())
+            && stack[i].getMethodName().equals("doPrivileged"))
+          break;
+
+        // Don't add a class twice.
+        if (classes.contains(stack[i].getClassName()))
+          continue;
+
+        classes.add(stack[i].getClassName());
+
+        // If there was a call to doPrivileged with a supplied context,
+        // return that context.
+        List pair = new ArrayList(2);
+        pair.add(Thread.currentThread());
+        pair.add(stack[i].getClassName());
+        if (contexts.containsKey(pair))
+          {
+            return (AccessControlContext) contexts.get(pair);
+          }
+
+        try
+          {
+            Class clazz = Class.forName(stack[i].getClassName());
+            ProtectionDomain domain = clazz.getProtectionDomain();
+            if (domain != null)
+              domains.add(domain);
+          }
+        catch (Exception x)
+          {
+            // XXX what to do if this fails?
+          }
+      }
+
+    return new AccessControlContext((ProtectionDomain[])
+                                    domains.toArray(new ProtectionDomain[domains.size()]));
   }
 }