>From a17eb9c48b6ad5f6144f4f6b4f96d2a3892dc46c Mon Sep 17 00:00:00 2001 From: David Kastrup Date: Sat, 20 Jun 2015 19:44:09 +0200 Subject: [PATCH] Replace upload.py, also let it use OAuth2 by default --- upload.py | 489 +++++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 407 insertions(+), 82 deletions(-) mode change 100755 => 100644 upload.py diff --git a/upload.py b/upload.py old mode 100755 new mode 100644 index 3a7f238..071fbee --- a/upload.py +++ b/upload.py @@ -34,6 +34,7 @@ against by using the '--rev' option. # This code is derived from appcfg.py in the App Engine SDK (open source), # and from ASPN recipe #146306. +import BaseHTTPServer import ConfigParser import cookielib import errno @@ -51,6 +52,9 @@ import sys import urllib import urllib2 import urlparse +import webbrowser + +from multiprocessing.pool import ThreadPool # The md5 module was deprecated in Python 2.5. try: @@ -74,6 +78,7 @@ except ImportError: # 2: Info logs. # 3: Debug logs. verbosity = 1 +LOGGER = logging.getLogger('upload') # The account type used for authentication. # This line could be changed by the review server (see handler for @@ -87,6 +92,7 @@ DEFAULT_REVIEW_SERVER = "codereview.appspot.com" # Max size of patch or base file. MAX_UPLOAD_SIZE = 900 * 1024 + # Constants for version control names. Used by GuessVCSName.
VCS_GIT = "Git" VCS_MERCURIAL = "Mercurial" @@ -95,16 +101,72 @@ VCS_PERFORCE = "Perforce" VCS_CVS = "CVS" VCS_UNKNOWN = "Unknown" -VCS_ABBREVIATIONS = { - VCS_MERCURIAL.lower(): VCS_MERCURIAL, - "hg": VCS_MERCURIAL, - VCS_SUBVERSION.lower(): VCS_SUBVERSION, - "svn": VCS_SUBVERSION, - VCS_PERFORCE.lower(): VCS_PERFORCE, - "p4": VCS_PERFORCE, - VCS_GIT.lower(): VCS_GIT, - VCS_CVS.lower(): VCS_CVS, -} +VCS = [ +{ + 'name': VCS_MERCURIAL, + 'aliases': ['hg', 'mercurial'], +}, { + 'name': VCS_SUBVERSION, + 'aliases': ['svn', 'subversion'], +}, { + 'name': VCS_PERFORCE, + 'aliases': ['p4', 'perforce'], +}, { + 'name': VCS_GIT, + 'aliases': ['git'], +}, { + 'name': VCS_CVS, + 'aliases': ['cvs'], +}] + +VCS_SHORT_NAMES = [] # hg, svn, ... +VCS_ABBREVIATIONS = {} # alias: name, ... +for vcs in VCS: + VCS_SHORT_NAMES.append(min(vcs['aliases'], key=len)) + VCS_ABBREVIATIONS.update((alias, vcs['name']) for alias in vcs['aliases']) + + +# OAuth 2.0-Related Constants +LOCALHOST_IP = '127.0.0.1' +DEFAULT_OAUTH2_PORT = 8001 +ACCESS_TOKEN_PARAM = 'access_token' +ERROR_PARAM = 'error' +OAUTH_DEFAULT_ERROR_MESSAGE = 'OAuth 2.0 error occurred.' +OAUTH_PATH = '/get-access-token' +OAUTH_PATH_PORT_TEMPLATE = OAUTH_PATH + '?port=%(port)d' +AUTH_HANDLER_RESPONSE = """\ + + + Authentication Status + + + +

The authentication flow has completed.

+ + +""" +# Borrowed from google-api-python-client +OPEN_LOCAL_MESSAGE_TEMPLATE = """\ +Your browser has been opened to visit: + + %s + +If your browser is on a different machine then exit and re-run +upload.py with the command-line parameter + + --no_oauth2_webbrowser +""" +NO_OPEN_LOCAL_MESSAGE_TEMPLATE = """\ +Go to the following link in your browser: + + %s + +and copy the access token. +""" # The result of parsing Subversion's [auto-props] setting. svn_auto_props_map = None @@ -179,8 +241,9 @@ class ClientLoginError(urllib2.HTTPError): class AbstractRpcServer(object): """Provides a common interface for a simple RPC server.""" - def __init__(self, host, auth_function, host_override=None, extra_headers={}, - save_cookies=False, account_type=AUTH_ACCOUNT_TYPE): + def __init__(self, host, auth_function, host_override=None, + extra_headers=None, save_cookies=False, + account_type=AUTH_ACCOUNT_TYPE): """Creates a new AbstractRpcServer. Args: @@ -203,14 +266,14 @@ class AbstractRpcServer(object): self.host_override = host_override self.auth_function = auth_function self.authenticated = False - self.extra_headers = extra_headers + self.extra_headers = extra_headers or {} self.save_cookies = save_cookies self.account_type = account_type self.opener = self._GetOpener() if self.host_override: - logging.info("Server: %s; Host: %s", self.host, self.host_override) + LOGGER.info("Server: %s; Host: %s", self.host, self.host_override) else: - logging.info("Server: %s", self.host) + LOGGER.info("Server: %s", self.host) def _GetOpener(self): """Returns an OpenerDirector for making HTTP requests. 
@@ -222,7 +285,7 @@ class AbstractRpcServer(object): def _CreateRequest(self, url, data=None): """Creates a new urllib request.""" - logging.debug("Creating request for: '%s' with payload:\n%s", url, data) + LOGGER.debug("Creating request for: '%s' with payload:\n%s", url, data) req = urllib2.Request(url, data=data, headers={"Accept": "text/plain"}) if self.host_override: req.add_header("Host", self.host_override) @@ -379,7 +442,7 @@ class AbstractRpcServer(object): """ # TODO: Don't require authentication. Let the server say # whether it is necessary. - if not self.authenticated: + if not self.authenticated and self.auth_function: self._Authenticate() old_timeout = socket.getdefaulttimeout() @@ -398,7 +461,7 @@ class AbstractRpcServer(object): for header, value in extra_headers.items(): req.add_header(header, value) try: - f = self.opener.open(req) + f = self.opener.open(req, timeout=70) response = f.read() f.close() return response @@ -406,6 +469,8 @@ class AbstractRpcServer(object): if tries > 3: raise elif e.code == 401 or e.code == 302: + if not self.auth_function: + raise self._Authenticate() elif e.code == 301: # Handle permanent redirect manually. @@ -413,7 +478,9 @@ class AbstractRpcServer(object): url_loc = urlparse.urlparse(url) self.host = '%s://%s' % (url_loc[0], url_loc[1]) elif e.code >= 500: - ErrorExit(e.read()) + # TODO: We should error out on a 500, but the server is too flaky + # for that at the moment. 
+ StatusUpdate('Upload got a 500 response: %d' % e.code) else: raise finally: @@ -425,10 +492,16 @@ class HttpRpcServer(AbstractRpcServer): def _Authenticate(self): """Save the cookie jar after authentication.""" - super(HttpRpcServer, self)._Authenticate() - if self.save_cookies: - StatusUpdate("Saving authentication cookies to %s" % self.cookie_file) - self.cookie_jar.save() + if isinstance(self.auth_function, OAuth2Creds): + access_token = self.auth_function() + if access_token is not None: + self.extra_headers['Authorization'] = 'OAuth %s' % (access_token,) + self.authenticated = True + else: + super(HttpRpcServer, self)._Authenticate() + if self.save_cookies: + StatusUpdate("Saving authentication cookies to %s" % self.cookie_file) + self.cookie_jar.save() def _GetOpener(self): """Returns an OpenerDirector that supports cookies and ignores redirects. @@ -495,7 +568,8 @@ class CondensedHelpFormatter(optparse.IndentedHelpFormatter): parser = optparse.OptionParser( - usage="%prog [options] [-- diff_options] [path...]", + usage=("%prog [options] [-- diff_options] [path...]\n" + "See also: https://github.com/rietveld-codereview/rietveld/wiki/upload.py-Usage"), add_help_option=False, formatter=CondensedHelpFormatter() ) @@ -531,12 +605,26 @@ group.add_option("-H", "--host", action="store", dest="host", group.add_option("--no_cookies", action="store_false", dest="save_cookies", default=True, help="Do not save authentication cookies to local disk.") +group.add_option("--oauth2", action="store_true", + dest="use_oauth2", default=True, + help="Use OAuth 2.0 instead of a password.") +group.add_option("--oauth2_port", action="store", type="int", + dest="oauth2_port", default=DEFAULT_OAUTH2_PORT, + help=("Port to use to handle OAuth 2.0 redirect. 
Must be an " + "integer in the range 1024-49151, defaults to " + "'%default'.")) +group.add_option("--no_oauth2_webbrowser", action="store_false", + dest="open_oauth2_local_webbrowser", default=True, + help="Don't open a browser window to get an access token.") group.add_option("--account_type", action="store", dest="account_type", metavar="TYPE", default=AUTH_ACCOUNT_TYPE, choices=["GOOGLE", "HOSTED"], help=("Override the default account type " "(defaults to '%default', " "valid choices are 'GOOGLE' and 'HOSTED').")) +group.add_option("-j", "--number-parallel-uploads", + dest="num_upload_threads", default=8, + help="Number of uploads to do in parallel.") # Issue group = parser.add_option_group("Issue options") group.add_option("-t", "--title", action="store", dest="title", @@ -581,8 +669,8 @@ group.add_option("-p", "--send_patch", action="store_true", "attachment, and prepend email subject with 'PATCH:'.") group.add_option("--vcs", action="store", dest="vcs", metavar="VCS", default=None, - help=("Version control system (optional, usually upload.py " - "already guesses the right VCS).")) + help=("Explicitly specify version control system (%s)" + % ", ".join(VCS_SHORT_NAMES))) group.add_option("--emulate_svn_auto_props", action="store_true", dest="emulate_svn_auto_props", default=False, help=("Emulate Subversion's auto properties feature.")) @@ -590,8 +678,11 @@ group.add_option("--emulate_svn_auto_props", action="store_true", group = parser.add_option_group("Git-specific options") group.add_option("--git_similarity", action="store", dest="git_similarity", metavar="SIM", type="int", default=50, - help=("Set the minimum similarity index for detecting renames " - "and copies. See `git diff -C`. (default 50).")) + help=("Set the minimum similarity percentage for detecting " + "renames and copies. See `git diff -C`. 
(default 50).")) +group.add_option("--git_only_search_patch", action="store_false", default=True, + dest='git_find_copies_harder', + help="Removes --find-copies-harder when seaching for copies") group.add_option("--git_no_find_copies", action="store_false", default=True, dest="git_find_copies", help=("Prevents git from looking for copies (default off).")) @@ -612,10 +703,158 @@ group.add_option("--p4_user", action="store", dest="p4_user", help=("Perforce user")) +# OAuth 2.0 Methods and Helpers +class ClientRedirectServer(BaseHTTPServer.HTTPServer): + """A server for redirects back to localhost from the associated server. + + Waits for a single request and parses the query parameters for an access token + or an error and then stops serving. + """ + access_token = None + error = None + + +class ClientRedirectHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """A handler for redirects back to localhost from the associated server. + + Waits for a single request and parses the query parameters into the server's + access_token or error and then stops serving. + """ + + def SetResponseValue(self): + """Stores the access token or error from the request on the server. + + Will only do this if exactly one query parameter was passed in to the + request and that query parameter used 'access_token' or 'error' as the key. + """ + query_string = urlparse.urlparse(self.path).query + query_params = urlparse.parse_qs(query_string) + + if len(query_params) == 1: + if query_params.has_key(ACCESS_TOKEN_PARAM): + access_token_list = query_params[ACCESS_TOKEN_PARAM] + if len(access_token_list) == 1: + self.server.access_token = access_token_list[0] + else: + error_list = query_params.get(ERROR_PARAM, []) + if len(error_list) == 1: + self.server.error = error_list[0] + + def do_GET(self): + """Handle a GET request. + + Parses and saves the query parameters and prints a message that the server + has completed its lone task (handling a redirect). 
+ + Note that we can't detect if an error occurred. + """ + self.send_response(200) + self.send_header('Content-type', 'text/html') + self.end_headers() + self.SetResponseValue() + self.wfile.write(AUTH_HANDLER_RESPONSE) + + def log_message(self, format, *args): + """Do not log messages to stdout while running as command line program.""" + pass + + +def OpenOAuth2ConsentPage(server=DEFAULT_REVIEW_SERVER, + port=DEFAULT_OAUTH2_PORT): + """Opens the OAuth 2.0 consent page or prints instructions how to. + + Uses the webbrowser module to open the OAuth server side page in a browser. + + Args: + server: String containing the review server URL. Defaults to + DEFAULT_REVIEW_SERVER. + port: Integer, the port where the localhost server receiving the redirect + is serving. Defaults to DEFAULT_OAUTH2_PORT. + + Returns: + A boolean indicating whether the page opened successfully. + """ + path = OAUTH_PATH_PORT_TEMPLATE % {'port': port} + parsed_url = urlparse.urlparse(server) + scheme = parsed_url[0] or 'https' + if scheme != 'https': + ErrorExit('Using OAuth requires a review server with SSL enabled.') + # If no scheme was given on command line the server address ends up in + # parsed_url.path otherwise in netloc. + host = parsed_url[1] or parsed_url[2] + page = '%s://%s%s' % (scheme, host, path) + page_opened = webbrowser.open(page, new=1, autoraise=True) + if page_opened: + print OPEN_LOCAL_MESSAGE_TEMPLATE % (page,) + return page_opened + + +def WaitForAccessToken(port=DEFAULT_OAUTH2_PORT): + """Spins up a simple HTTP Server to handle a single request. + + Intended to handle a single redirect from the production server after the + user authenticated via OAuth 2.0 with the server. + + Args: + port: Integer, the port where the localhost server receiving the redirect + is serving. Defaults to DEFAULT_OAUTH2_PORT. + + Returns: + The access token passed to the localhost server, or None if no access token + was passed. 
+ """ + httpd = ClientRedirectServer((LOCALHOST_IP, port), ClientRedirectHandler) + # Wait to serve just one request before deferring control back + # to the caller of wait_for_refresh_token + httpd.handle_request() + if httpd.access_token is None: + ErrorExit(httpd.error or OAUTH_DEFAULT_ERROR_MESSAGE) + return httpd.access_token + + +def GetAccessToken(server=DEFAULT_REVIEW_SERVER, port=DEFAULT_OAUTH2_PORT, + open_local_webbrowser=True): + """Gets an Access Token for the current user. + + Args: + server: String containing the review server URL. Defaults to + DEFAULT_REVIEW_SERVER. + port: Integer, the port where the localhost server receiving the redirect + is serving. Defaults to DEFAULT_OAUTH2_PORT. + open_local_webbrowser: Boolean, defaults to True. If set, opens a page in + the user's browser. + + Returns: + A string access token that was sent to the local server. If the serving page + via WaitForAccessToken does not receive an access token, this method + returns None. + """ + access_token = None + if open_local_webbrowser: + page_opened = OpenOAuth2ConsentPage(server=server, port=port) + if page_opened: + try: + access_token = WaitForAccessToken(port=port) + except socket.error, e: + print 'Can\'t start local webserver. Socket Error: %s\n' % (e.strerror,) + + if access_token is None: + # TODO(dhermes): Offer to add to clipboard using xsel, xclip, pbcopy, etc. + page = 'https://%s%s' % (server, OAUTH_PATH) + print NO_OPEN_LOCAL_MESSAGE_TEMPLATE % (page,) + access_token = raw_input('Enter access token: ').strip() + + return access_token + + class KeyringCreds(object): def __init__(self, server, host, email): self.server = server - self.host = host + # Explicitly cast host to str to work around bug in old versions of Keyring + # (versions before 0.10). Even though newer versions of Keyring fix this, + # some modern linuxes (such as Ubuntu 12.04) still bundle a version with + # the bug. 
+ self.host = str(host) self.email = email self.accounts_seen = set() @@ -653,8 +892,24 @@ class KeyringCreds(object): return (email, password) +class OAuth2Creds(object): + """Simple object to hold server and port to be passed to GetAccessToken.""" + + def __init__(self, server, port, open_local_webbrowser=True): + self.server = server + self.port = port + self.open_local_webbrowser = open_local_webbrowser + + def __call__(self): + """Uses stored server and port to retrieve OAuth 2.0 access token.""" + return GetAccessToken(server=self.server, port=self.port, + open_local_webbrowser=self.open_local_webbrowser) + + def GetRpcServer(server, email=None, host_override=None, save_cookies=True, - account_type=AUTH_ACCOUNT_TYPE): + account_type=AUTH_ACCOUNT_TYPE, use_oauth2=False, + oauth2_port=DEFAULT_OAUTH2_PORT, + open_oauth2_local_webbrowser=True): """Returns an instance of an AbstractRpcServer. Args: @@ -665,17 +920,22 @@ def GetRpcServer(server, email=None, host_override=None, save_cookies=True, save_cookies: Whether authentication cookies should be saved to disk. account_type: Account type for authentication, either 'GOOGLE' or 'HOSTED'. Defaults to AUTH_ACCOUNT_TYPE. + use_oauth2: Boolean indicating whether OAuth 2.0 should be used for + authentication. + oauth2_port: Integer, the port where the localhost server receiving the + redirect is serving. Defaults to DEFAULT_OAUTH2_PORT. + open_oauth2_local_webbrowser: Boolean, defaults to True. If True and using + OAuth, this opens a page in the user's browser to obtain a token. Returns: A new HttpRpcServer, on which RPC calls can be made. """ - # If this is the dev_appserver, use fake authentication. host = (host_override or server).lower() if re.match(r'(http://)?localhost([:/]|$)', host): if email is None: email = "address@hidden" - logging.info("Using debug user %s. Override with --email" % email) + LOGGER.info("Using debug user %s. 
Override with --email" % email) server = HttpRpcServer( server, lambda: (email, "password"), @@ -688,8 +948,13 @@ def GetRpcServer(server, email=None, host_override=None, save_cookies=True, server.authenticated = True return server - return HttpRpcServer(server, - KeyringCreds(server, host, email).GetUserCredentials, + positional_args = [server] + if use_oauth2: + positional_args.append( + OAuth2Creds(server, oauth2_port, open_oauth2_local_webbrowser)) + else: + positional_args.append(KeyringCreds(server, host, email).GetUserCredentials) + return HttpRpcServer(*positional_args, host_override=host_override, save_cookies=save_cookies, account_type=account_type) @@ -756,7 +1021,7 @@ def RunShellWithReturnCodeAndStderr(command, print_output=False, Returns: Tuple (stdout, stderr, return code) """ - logging.info("Running %s", command) + LOGGER.info("Running %s", command) env = env.copy() env['LC_MESSAGES'] = 'C' p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -896,13 +1161,13 @@ class VersionControlSystem(object): else: type = "current" if len(content) > MAX_UPLOAD_SIZE: - print ("Not uploading the %s file for %s because it's too large." % - (type, filename)) + result = ("Not uploading the %s file for %s because it's too large." 
% + (type, filename)) file_too_large = True content = "" + elif options.verbose: + result = "Uploading %s file for %s" % (type, filename) checksum = md5(content).hexdigest() - if options.verbose > 0 and not file_too_large: - print "Uploading %s file for %s" % (type, filename) url = "/%d/upload_content/%d/%d" % (int(issue), int(patchset), file_id) form_fields = [("filename", filename), ("status", status), @@ -916,14 +1181,24 @@ class VersionControlSystem(object): form_fields.append(("user", options.email)) ctype, body = EncodeMultipartFormData(form_fields, [("data", filename, content)]) - response_body = rpc_server.Send(url, body, - content_type=ctype) + try: + response_body = rpc_server.Send(url, body, content_type=ctype) + except urllib2.HTTPError, e: + response_body = ("Failed to upload file for %s. Got %d status code." % + (filename, e.code)) + if not response_body.startswith("OK"): StatusUpdate(" --> %s" % response_body) sys.exit(1) + return result + patches = dict() [patches.setdefault(v, k) for k, v in patch_list] + + threads = [] + thread_pool = ThreadPool(options.num_upload_threads) + for filename in patches.keys(): base_content, new_content, is_binary, status = files[filename] file_id_str = patches.get(filename) @@ -932,16 +1207,24 @@ class VersionControlSystem(object): file_id_str = file_id_str[file_id_str.rfind("_") + 1:] file_id = int(file_id_str) if base_content != None: - UploadFile(filename, file_id, base_content, is_binary, status, True) + t = thread_pool.apply_async(UploadFile, args=(filename, + file_id, base_content, is_binary, status, True)) + threads.append(t) if new_content != None: - UploadFile(filename, file_id, new_content, is_binary, status, False) + t = thread_pool.apply_async(UploadFile, args=(filename, + file_id, new_content, is_binary, status, False)) + threads.append(t) + + for t in threads: + print t.get(timeout=60) + def IsImage(self, filename): """Returns true if the filename has an image extension.""" mimetype = 
mimetypes.guess_type(filename)[0] if not mimetype: return False - return mimetype.startswith("image/") + return mimetype.startswith("image/") and not mimetype.startswith("image/svg") def IsBinaryData(self, data): """Returns true if data contains a null byte.""" @@ -1000,7 +1283,7 @@ class SubversionVCS(VersionControlSystem): path = path + "/" base = urlparse.urlunparse((scheme, netloc, path, params, query, fragment)) - logging.info("Guessed %sbase = %s", guess, base) + LOGGER.info("Guessed %sbase = %s", guess, base) return base if required: ErrorExit("Can't find URL in output from svn info") @@ -1019,7 +1302,7 @@ class SubversionVCS(VersionControlSystem): return filename def GenerateDiff(self, args): - cmd = ["svn", "diff"] + cmd = ["svn", "diff", "--internal-diff"] if self.options.revision: cmd += ["-r", self.options.revision] cmd.extend(args) @@ -1028,7 +1311,7 @@ class SubversionVCS(VersionControlSystem): for line in data.splitlines(): if line.startswith("Index:") or line.startswith("Property changes on:"): count += 1 - logging.info(line) + LOGGER.info(line) if not count: ErrorExit("No valid patches found in output from svn diff") return data @@ -1333,16 +1616,18 @@ class GitVCS(VersionControlSystem): # append a diff (with rename detection), without deletes. 
cmd = [ "git", "diff", "--no-color", "--no-ext-diff", "--full-index", - "--ignore-submodules", "--binary", + "--ignore-submodules", "--src-prefix=a/", "--dst-prefix=b/", ] diff = RunShell( cmd + ["--no-renames", "--diff-filter=D"] + extra_args, env=env, silent_ok=True) + assert 0 <= self.options.git_similarity <= 100 if self.options.git_find_copies: - similarity_options = ["--find-copies-harder", "-l100000", - "-C%s" % self.options.git_similarity ] + similarity_options = ["-l100000", "-C%d%%" % self.options.git_similarity] + if self.options.git_find_copies_harder: + similarity_options.append("--find-copies-harder") else: - similarity_options = ["-M%s" % self.options.git_similarity ] + similarity_options = ["-M%d%%" % self.options.git_similarity ] diff += RunShell( cmd + ["--diff-filter=AMCRT"] + similarity_options + extra_args, env=env, silent_ok=True) @@ -1358,10 +1643,10 @@ class GitVCS(VersionControlSystem): silent_ok=True) return status.splitlines() - def GetFileContent(self, file_hash, is_binary): + def GetFileContent(self, file_hash): """Returns the content of a file identified by its git hash.""" data, retcode = RunShellWithReturnCode(["git", "show", file_hash], - universal_newlines=not is_binary) + universal_newlines=False) if retcode: ErrorExit("Got error status from 'git show %s'" % file_hash) return data @@ -1377,7 +1662,8 @@ class GitVCS(VersionControlSystem): if filename not in self.hashes: # If a rename doesn't change the content, we never get a hash. base_content = RunShell( - ["git", "show", "HEAD:" + filename], silent_ok=True) + ["git", "show", "HEAD:" + filename], silent_ok=True, + universal_newlines=False) elif not hash_before: status = "A" base_content = "" @@ -1386,18 +1672,22 @@ class GitVCS(VersionControlSystem): else: status = "M" - is_image = self.IsImage(filename) - is_binary = is_image or self.IsBinaryData(base_content) - # Grab the before/after content if we need it. # Grab the base content if we don't have it already. 
if base_content is None and hash_before: - base_content = self.GetFileContent(hash_before, is_binary) + base_content = self.GetFileContent(hash_before) + + is_binary = self.IsImage(filename) + if base_content: + is_binary = is_binary or self.IsBinaryData(base_content) + # Only include the "after" file if it's an image; otherwise it # it is reconstructed from the diff. - if is_image and hash_after: - new_content = self.GetFileContent(hash_after, is_binary) - + if hash_after: + new_content = self.GetFileContent(hash_after) + is_binary = is_binary or self.IsBinaryData(new_content) + if not is_binary: + new_content = None return (base_content, new_content, is_binary, status) @@ -1453,7 +1743,7 @@ class CVSVCS(VersionControlSystem): for line in data.splitlines(): if line.startswith("Index:"): count += 1 - logging.info(line) + LOGGER.info(line) if not count: ErrorExit("No valid patches found in output from cvs diff") @@ -1484,7 +1774,11 @@ class MercurialVCS(VersionControlSystem): if self.options.revision: self.base_rev = self.options.revision else: - self.base_rev = RunShell(["hg", "parent", "-q"]).split(':')[1].strip() + parent = RunShell(["hg", "parent", "-q"], silent_ok=True) + if parent: + self.base_rev = parent.split(':')[1].strip() + else: + self.base_rev = '0' def GetGUID(self): # See chapter "Uniquely identifying a repository" @@ -1515,7 +1809,7 @@ class MercurialVCS(VersionControlSystem): svndiff.append("Index: %s" % filename) svndiff.append("=" * 67) filecount += 1 - logging.info(line) + LOGGER.info(line) else: svndiff.append(line) if not filecount: @@ -1941,26 +2235,48 @@ def UploadSeparatePatches(issue, rpc_server, patchset, data, options): Returns a list of [patch_key, filename] for each file. 
""" - patches = SplitPatch(data) - rv = [] - for patch in patches: - if len(patch[1]) > MAX_UPLOAD_SIZE: - print ("Not uploading the patch for " + patch[0] + - " because the file is too large.") - continue - form_fields = [("filename", patch[0])] + def UploadFile(filename, data): + form_fields = [("filename", filename)] if not options.download_base: form_fields.append(("content_upload", "1")) - files = [("data", "data.diff", patch[1])] + files = [("data", "data.diff", data)] ctype, body = EncodeMultipartFormData(form_fields, files) url = "/%d/upload_patch/%d" % (int(issue), int(patchset)) - print "Uploading patch for " + patch[0] - response_body = rpc_server.Send(url, body, content_type=ctype) + + try: + response_body = rpc_server.Send(url, body, content_type=ctype) + except urllib2.HTTPError, e: + response_body = ("Failed to upload patch for %s. Got %d status code." % + (filename, e.code)) + lines = response_body.splitlines() if not lines or lines[0] != "OK": StatusUpdate(" --> %s" % response_body) sys.exit(1) - rv.append([lines[1], patch[0]]) + return ("Uploaded patch for " + filename, [lines[1], filename]) + + threads = [] + thread_pool = ThreadPool(options.num_upload_threads) + + patches = SplitPatch(data) + rv = [] + for patch in patches: + if len(patch[1]) > MAX_UPLOAD_SIZE: + print ("Not uploading the patch for " + patch[0] + + " because the file is too large.") + continue + + filename = patch[0] + data = patch[1] + + t = thread_pool.apply_async(UploadFile, args=(filename, data)) + threads.append(t) + + for t in threads: + result = t.get(timeout=60) + print result[0] + rv.append(result[1]) + return rv @@ -2210,7 +2526,11 @@ def RealMain(argv, data=None): if options.help: if options.verbose < 2: # hide Perforce options - parser.epilog = "Use '--help -v' to show additional Perforce options." + parser.epilog = ( + "Use '--help -v' to show additional Perforce options. 
" + "For more help, see " + "https://github.com/rietveld-codereview/rietveld/wiki" + ) parser.option_groups.remove(parser.get_option_group('--p4_port')) parser.print_help() sys.exit(0) @@ -2218,9 +2538,9 @@ def RealMain(argv, data=None): global verbosity verbosity = options.verbose if verbosity >= 3: - logging.getLogger().setLevel(logging.DEBUG) + LOGGER.setLevel(logging.DEBUG) elif verbosity >= 2: - logging.getLogger().setLevel(logging.INFO) + LOGGER.setLevel(logging.INFO) vcs = GuessVCS(options) @@ -2238,7 +2558,7 @@ def RealMain(argv, data=None): if not base and options.download_base: options.download_base = True - logging.info("Enabled upload of base file") + LOGGER.info("Enabled upload of base file") if not options.assume_yes: vcs.CheckForUnknownFiles() if data is None: @@ -2251,11 +2571,16 @@ def RealMain(argv, data=None): files = vcs.GetBaseFiles(data) if verbosity >= 1: print "Upload server:", options.server, "(change with -s/--server)" + if options.use_oauth2: + options.save_cookies = False rpc_server = GetRpcServer(options.server, options.email, options.host, options.save_cookies, - options.account_type) + options.account_type, + options.use_oauth2, + options.oauth2_port, + options.open_oauth2_local_webbrowser) form_fields = [] repo_guid = vcs.GetGUID() @@ -2265,7 +2590,7 @@ def RealMain(argv, data=None): b = urlparse.urlparse(base) username, netloc = urllib.splituser(b.netloc) if username: - logging.info("Removed username from base URL") + LOGGER.info("Removed username from base URL") base = urlparse.urlunparse((b.scheme, netloc, b.path, b.params, b.query, b.fragment)) form_fields.append(("base", base)) -- 2.1.4