Index: java/security/AccessController.java
===================================================================
RCS file: /cvsroot/classpath/classpath/java/security/AccessController.java,v
retrieving revision 1.8
diff -u -r1.8 AccessController.java
--- java/security/AccessController.java 3 Jun 2004 09:16:58 -0000 1.8
+++ java/security/AccessController.java 25 Jul 2004 03:02:59 -0000
@@ -95,7 +95,15 @@
*/
public static Object doPrivileged(PrivilegedAction action)
{
- return action.run();
+ VMAccessController.pushContext (null);
+ try
+ {
+ return action.run();
+ }
+ finally
+ {
+ VMAccessController.popContext();
+ }
}
/**
@@ -113,16 +121,16 @@
* @return the result of the action.run() method.
*/
public static Object doPrivileged(PrivilegedAction action,
- AccessControlContext context)
+ AccessControlContext context)
{
- VMAccessController.pushContext (context, action.getClass());
+ VMAccessController.pushContext (context);
try
{
return action.run();
}
finally
{
- VMAccessController.popContext (action.getClass());
+ VMAccessController.popContext();
}
}
@@ -145,14 +153,18 @@
public static Object doPrivileged(PrivilegedExceptionAction action)
throws PrivilegedActionException
{
-
+ VMAccessController.pushContext (null);
try
{
- return action.run();
+ return action.run();
}
catch (Exception e)
{
- throw new PrivilegedActionException(e);
+ throw new PrivilegedActionException(e);
+ }
+ finally
+ {
+ VMAccessController.popContext();
}
}
@@ -175,22 +187,22 @@
* is thrown in the run() method.
*/
public static Object doPrivileged(PrivilegedExceptionAction action,
- AccessControlContext context)
+ AccessControlContext context)
throws PrivilegedActionException
{
- VMAccessController.pushContext (context, action.getClass());
+ VMAccessController.pushContext (context);
try
{
- return action.run();
+ return action.run();
}
catch (Exception e)
{
- throw new PrivilegedActionException(e);
+ throw new PrivilegedActionException(e);
}
finally
{
- VMAccessController.popContext (action.getClass());
+ VMAccessController.popContext();
}
}
Index: vm/reference/java/security/VMAccessController.java
===================================================================
RCS file: /cvsroot/classpath/classpath/vm/reference/java/security/VMAccessController.java,v
retrieving revision 1.2
diff -u -r1.2 VMAccessController.java
--- vm/reference/java/security/VMAccessController.java 4 Jul 2004 18:32:32 -0000 1.2
+++ vm/reference/java/security/VMAccessController.java 25 Jul 2004 03:02:59 -0000
@@ -52,15 +52,24 @@
// -------------------------------------------------------------------------
/**
- * A mapping between pairs (thread, classname) to access
- * control contexts. The thread and classname are the thread
- * and classname current as of the last call to doPrivileged with
- * an AccessControlContext argument.
+ * This is a per-thread stack of AccessControlContext objects (which can
+ * be null) for each call to AccessController.doPrivileged in each thread's
+ * call stack. We use this to remember which context object corresponds to
+ * which call.
*/
- private static final Map contexts = Collections.synchronizedMap(new HashMap());
+ private static final ThreadLocal contexts = new ThreadLocal();
+ /**
+ * This is a Boolean that, if set, tells getContext that it has already
+ * been called once, allowing us to handle recursive permission checks
+ * caused by methods getContext calls.
+ */
private static final ThreadLocal inGetContext = new ThreadLocal();
+ /**
+ * And we return this all-permissive context to ensure that privileged
+ * methods called from getContext succeed.
+ */
private final static AccessControlContext DEFAULT_CONTEXT;
static
{
@@ -97,15 +106,17 @@
* pushed from one thread will not be available to another.
*
* @param acc The access control context.
- * @param clazz The class that implements {@link PrivilegedAction}.
*/
- static void pushContext (AccessControlContext acc, Class clazz)
+ static void pushContext (AccessControlContext acc)
{
- ArrayList pair = new ArrayList (2);
- pair.add (Thread.currentThread());
- pair.add (clazz);
- if (DEBUG) debug ("pushing " + pair);
- contexts.put (pair, acc);
+ if (DEBUG) debug ("pushing " + acc);
+ LinkedList stack = (LinkedList) contexts.get();
+ if (stack == null)
+ {
+ stack = new LinkedList();
+ contexts.set (stack);
+ }
+ stack.addFirst (acc);
}
/**
@@ -113,16 +124,20 @@
* This method is used by {@link AccessController} when exiting from a
* call to {@link
* AccessController#doPrivileged(java.security.PrivilegedAction,java.security.AccessControlContext)}.
- *
- * @param clazz The class that implements {@link PrivilegedAction}.
*/
- static void popContext (Class clazz)
+ static void popContext()
{
- ArrayList pair = new ArrayList (2);
- pair.add (Thread.currentThread());
- pair.add (clazz);
- if (DEBUG) debug ("popping " + pair);
- contexts.remove (pair);
+ if (DEBUG) debug ("popping context");
+
+ // Stack should never be null, nor should it be empty, if this method
+ // and its counterpart have been called properly.
+ LinkedList stack = (LinkedList) contexts.get();
+ if (stack != null)
+ {
+ stack.removeFirst();
+ if (stack.isEmpty())
+ contexts.set (null);
+ }
}
/**
@@ -147,24 +162,25 @@
return DEFAULT_CONTEXT;
}
+ inGetContext.set (Boolean.TRUE);
+
Object[][] stack = getStack();
Class[] classes = (Class[]) stack[0];
String[] methods = (String[]) stack[1];
- inGetContext.set (Boolean.TRUE);
-
if (DEBUG) debug (">>> got trace of length " + classes.length);
HashSet domains = new HashSet();
HashSet seenDomains = new HashSet();
AccessControlContext context = null;
+ int privileged = 0;
// We walk down the stack, adding each ProtectionDomain for each
// class in the call stack. If we reach a call to doPrivileged,
// we don't add any more stack frames. We skip the first three stack
// frames, since they comprise the calls to getStack, getContext,
// and AccessController.getContext.
- for (int i = 3; i < classes.length; i++)
+ for (int i = 3; i < classes.length && privileged < 2; i++)
{
Class clazz = classes[i];
String method = methods[i];
@@ -175,17 +191,17 @@
debug (">>> loader = " + clazz.getClassLoader());
}
+ if (privileged == 1)
+ privileged = 2;
if (clazz.equals (AccessController.class)
&& method.equals ("doPrivileged"))
{
// If there was a call to doPrivileged with a supplied context,
// return that context.
- List pair = new ArrayList(2);
- pair.add (Thread.currentThread());
- pair.add (classes[i-1]);
- if (contexts.containsKey (pair))
- context = (AccessControlContext) contexts.get (pair);
- break;
+ LinkedList l = (LinkedList) contexts.get();
+ if (l != null)
+ context = (AccessControlContext) l.getFirst();
+ privileged = 1;
}
ProtectionDomain domain = clazz.getProtectionDomain();