testing_support: switch to 4 space indent

Reformat this dir by itself to help merging with conflicts with other CLs.

Reformatted using:
parallel ./yapf -i -- testing_support/*.py
~/chromiumos/chromite/contrib/reflow_overlong_comments testing_support/*.py

The files that still had strings that were too long were manually
reformatted.
testing_support/coverage_utils.py
testing_support/fake_repos.py
testing_support/git_test_utils.py
testing_support/presubmit_canned_checks_test_mocks.py

Change-Id: I4726a4bbd279a70bcf65d0987fcff0ff9a231386
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/tools/depot_tools/+/4842593
Reviewed-by: Josip Sokcevic <sokcevic@chromium.org>
Commit-Queue: Josip Sokcevic <sokcevic@chromium.org>
Auto-Submit: Mike Frysinger <vapier@chromium.org>
This commit is contained in:
Mike Frysinger
2023-09-05 20:14:34 +00:00
committed by LUCI CQ
parent 691128f836
commit f38dc929a8
10 changed files with 1396 additions and 1310 deletions

View File

@@ -0,0 +1,3 @@
[style]
based_on_style = pep8
column_limit = 80

View File

@@ -10,22 +10,26 @@ import sys
import textwrap import textwrap
import unittest import unittest
ROOT_PATH = os.path.abspath(os.path.join( ROOT_PATH = os.path.abspath(
os.path.dirname(os.path.dirname(__file__)))) os.path.join(os.path.dirname(os.path.dirname(__file__))))
def native_error(msg, version): def native_error(msg, version):
print(textwrap.dedent("""\ print(
textwrap.dedent("""\
ERROR: Native python-coverage (version: %s) is required to be ERROR: Native python-coverage (version: %s) is required to be
installed on your PYTHONPATH to run this test. Recommendation: installed on your PYTHONPATH to run this test. Recommendation:
sudo apt-get install pip sudo apt-get install pip
sudo pip install --upgrade coverage sudo pip install --upgrade coverage
%s""") % (version, msg)) %s""") % (version, msg))
sys.exit(1) sys.exit(1)
def covered_main(includes, require_native=None, required_percentage=100.0,
def covered_main(includes,
require_native=None,
required_percentage=100.0,
disable_coverage=True): disable_coverage=True):
"""Equivalent of unittest.main(), except that it gathers coverage data, and """Equivalent of unittest.main(), except that it gathers coverage data, and
asserts if the test is not at 100% coverage. asserts if the test is not at 100% coverage.
Args: Args:
@@ -37,43 +41,45 @@ def covered_main(includes, require_native=None, required_percentage=100.0,
disable_coverage (bool) - If True, just run unittest.main() without any disable_coverage (bool) - If True, just run unittest.main() without any
coverage tracking. Bug: crbug.com/662277 coverage tracking. Bug: crbug.com/662277
""" """
if disable_coverage: if disable_coverage:
unittest.main() unittest.main()
return return
try: try:
import coverage import coverage
if require_native is not None: if require_native is not None:
got_ver = coverage.__version__ got_ver = coverage.__version__
if not getattr(coverage.collector, 'CTracer', None): if not getattr(coverage.collector, 'CTracer', None):
native_error(( native_error(
"Native python-coverage module required.\n" ("Native python-coverage module required.\n"
"Pure-python implementation (version: %s) found: %s" "Pure-python implementation (version: %s) found: %s") %
) % (got_ver, coverage), require_native) (got_ver, coverage), require_native)
if got_ver < distutils.version.LooseVersion(require_native): if got_ver < distutils.version.LooseVersion(require_native):
native_error("Wrong version (%s) found: %s" % (got_ver, coverage), native_error(
require_native) "Wrong version (%s) found: %s" % (got_ver, coverage),
except ImportError: require_native)
if require_native is None: except ImportError:
sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party')) if require_native is None:
import coverage sys.path.insert(0, os.path.join(ROOT_PATH, 'third_party'))
else: import coverage
print("ERROR: python-coverage (%s) is required to be installed on your " else:
"PYTHONPATH to run this test." % require_native) print(
sys.exit(1) "ERROR: python-coverage (%s) is required to be installed on "
"your PYTHONPATH to run this test." % require_native)
sys.exit(1)
COVERAGE = coverage.coverage(include=includes) COVERAGE = coverage.coverage(include=includes)
COVERAGE.start() COVERAGE.start()
retcode = 0 retcode = 0
try: try:
unittest.main() unittest.main()
except SystemExit as e: except SystemExit as e:
retcode = e.code or retcode retcode = e.code or retcode
COVERAGE.stop() COVERAGE.stop()
if COVERAGE.report() < required_percentage: if COVERAGE.report() < required_percentage:
print('FATAL: not at required %f%% coverage.' % required_percentage) print('FATAL: not at required %f%% coverage.' % required_percentage)
retcode = 2 retcode = 2
return retcode return retcode

View File

@@ -74,129 +74,131 @@ DESCRIBE_JSON_TEMPLATE = """{
def parse_cipd(root, contents): def parse_cipd(root, contents):
tree = {} tree = {}
current_subdir = None current_subdir = None
for line in contents: for line in contents:
line = line.strip() line = line.strip()
match = re.match(CIPD_SUBDIR_RE, line) match = re.match(CIPD_SUBDIR_RE, line)
if match: if match:
print('match') print('match')
current_subdir = os.path.join(root, *match.group(1).split('/')) current_subdir = os.path.join(root, *match.group(1).split('/'))
if not root: if not root:
current_subdir = match.group(1) current_subdir = match.group(1)
elif line and current_subdir: elif line and current_subdir:
print('no match') print('no match')
tree.setdefault(current_subdir, []).append(line) tree.setdefault(current_subdir, []).append(line)
return tree return tree
def expand_package_name_cmd(package_name): def expand_package_name_cmd(package_name):
package_split = package_name.split("/") package_split = package_name.split("/")
suffix = package_split[-1] suffix = package_split[-1]
# Any use of var equality should return empty for testing. # Any use of var equality should return empty for testing.
if "=" in suffix: if "=" in suffix:
if suffix != "${platform=fake-platform-ok}": if suffix != "${platform=fake-platform-ok}":
return "" return ""
package_name = "/".join(package_split[:-1] + ["${platform}"]) package_name = "/".join(package_split[:-1] + ["${platform}"])
for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]: for v in [ARCH_VAR, OS_VAR, PLATFORM_VAR]:
var = "${%s}" % v var = "${%s}" % v
if package_name.endswith(var): if package_name.endswith(var):
package_name = package_name.replace(var, "%s-expanded-test-only" % v) package_name = package_name.replace(var,
return package_name "%s-expanded-test-only" % v)
return package_name
def ensure_file_resolve(): def ensure_file_resolve():
resolved = {"result": {}} resolved = {"result": {}}
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file', required=True) parser.add_argument('-ensure-file', required=True)
parser.add_argument('-json-output') parser.add_argument('-json-output')
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f: with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd("", f.readlines()) new_content = parse_cipd("", f.readlines())
for path, packages in new_content.items(): for path, packages in new_content.items():
resolved_packages = [] resolved_packages = []
for package in packages: for package in packages:
package_name = expand_package_name_cmd(package.split(" ")[0]) package_name = expand_package_name_cmd(package.split(" ")[0])
resolved_packages.append({ resolved_packages.append({
"package": package_name, "package": package_name,
"pin": { "pin": {
"package": package_name, "package": package_name,
"instance_id": package_name + "-fake-resolved-id", "instance_id": package_name + "-fake-resolved-id",
} }
}) })
resolved["result"][path] = resolved_packages resolved["result"][path] = resolved_packages
with io.open(args.json_output, 'w', encoding='utf-8') as f: with io.open(args.json_output, 'w', encoding='utf-8') as f:
f.write(json.dumps(resolved, indent=4)) f.write(json.dumps(resolved, indent=4))
def describe_cmd(package_name): def describe_cmd(package_name):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-json-output') parser.add_argument('-json-output')
parser.add_argument('-version', required=True) parser.add_argument('-version', required=True)
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute( json_template = Template(DESCRIBE_JSON_TEMPLATE).substitute(
package=package_name) package=package_name)
cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute(package=package_name) cli_out = Template(DESCRIBE_STDOUT_TEMPLATE).substitute(
json_out = json.loads(json_template) package=package_name)
found = False json_out = json.loads(json_template)
for tag in json_out['result']['tags']: found = False
if tag['tag'] == args.version: for tag in json_out['result']['tags']:
found = True if tag['tag'] == args.version:
break found = True
for tag in json_out['result']['refs']: break
if tag['ref'] == args.version: for tag in json_out['result']['refs']:
found = True if tag['ref'] == args.version:
break found = True
if found: break
if args.json_output: if found:
with io.open(args.json_output, 'w', encoding='utf-8') as f: if args.json_output:
f.write(json.dumps(json_out, indent=4)) with io.open(args.json_output, 'w', encoding='utf-8') as f:
sys.stdout.write(cli_out) f.write(json.dumps(json_out, indent=4))
return 0 sys.stdout.write(cli_out)
sys.stdout.write('Error: no such ref.\n') return 0
return 1 sys.stdout.write('Error: no such ref.\n')
return 1
def main(): def main():
cmd = sys.argv[1] cmd = sys.argv[1]
assert cmd in [ assert cmd in [
CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG, CIPD_DESCRIBE, CIPD_ENSURE, CIPD_ENSURE_FILE_RESOLVE, CIPD_EXPAND_PKG,
CIPD_EXPORT CIPD_EXPORT
] ]
# Handle cipd expand-package-name # Handle cipd expand-package-name
if cmd == CIPD_EXPAND_PKG: if cmd == CIPD_EXPAND_PKG:
# Expecting argument after cmd # Expecting argument after cmd
assert len(sys.argv) == 3 assert len(sys.argv) == 3
# Write result to stdout # Write result to stdout
sys.stdout.write(expand_package_name_cmd(sys.argv[2])) sys.stdout.write(expand_package_name_cmd(sys.argv[2]))
return 0
if cmd == CIPD_DESCRIBE:
# Expecting argument after cmd
assert len(sys.argv) >= 3
return describe_cmd(sys.argv[2])
if cmd == CIPD_ENSURE_FILE_RESOLVE:
return ensure_file_resolve()
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file')
parser.add_argument('-root')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd(args.root, f.readlines())
# Install new packages
for path, packages in new_content.items():
if not os.path.exists(path):
os.makedirs(path)
with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f:
f.write('\n'.join(packages))
# Save the ensure file that we got
shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd'))
return 0 return 0
if cmd == CIPD_DESCRIBE:
# Expecting argument after cmd
assert len(sys.argv) >= 3
return describe_cmd(sys.argv[2])
if cmd == CIPD_ENSURE_FILE_RESOLVE:
return ensure_file_resolve()
parser = argparse.ArgumentParser()
parser.add_argument('-ensure-file')
parser.add_argument('-root')
args, _ = parser.parse_known_args()
with io.open(args.ensure_file, 'r', encoding='utf-8') as f:
new_content = parse_cipd(args.root, f.readlines())
# Install new packages
for path, packages in new_content.items():
if not os.path.exists(path):
os.makedirs(path)
with io.open(os.path.join(path, '_cipd'), 'w', encoding='utf-8') as f:
f.write('\n'.join(packages))
# Save the ensure file that we got
shutil.copy(args.ensure_file, os.path.join(args.root, '_cipd'))
return 0
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main()) sys.exit(main())

File diff suppressed because it is too large Load Diff

View File

@@ -10,98 +10,98 @@ from io import StringIO
def _RaiseNotFound(path): def _RaiseNotFound(path):
raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT)) raise IOError(errno.ENOENT, path, os.strerror(errno.ENOENT))
class MockFileSystem(object): class MockFileSystem(object):
"""Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock """Stripped-down version of WebKit's webkitpy.common.system.filesystem_mock
Implements a filesystem-like interface on top of a dict of filenames -> Implements a filesystem-like interface on top of a dict of filenames ->
file contents. A file content value of None indicates that the file should file contents. A file content value of None indicates that the file should
not exist (IOError will be raised if it is opened; not exist (IOError will be raised if it is opened;
reading from a missing key raises a KeyError, not an IOError.""" reading from a missing key raises a KeyError, not an IOError."""
def __init__(self, files=None):
self.files = files or {}
self.written_files = {}
self._sep = '/'
def __init__(self, files=None): @property
self.files = files or {} def sep(self):
self.written_files = {} return self._sep
self._sep = '/'
@property def abspath(self, path):
def sep(self): if path.endswith(self.sep):
return self._sep return path[:-1]
return path
def abspath(self, path): def basename(self, path):
if path.endswith(self.sep): if self.sep not in path:
return path[:-1] return ''
return path return self.split(path)[-1] or self.sep
def basename(self, path): def dirname(self, path):
if self.sep not in path: if self.sep not in path:
return '' return ''
return self.split(path)[-1] or self.sep return self.split(path)[0] or self.sep
def dirname(self, path): def exists(self, path):
if self.sep not in path: return self.isfile(path) or self.isdir(path)
return ''
return self.split(path)[0] or self.sep
def exists(self, path): def isabs(self, path):
return self.isfile(path) or self.isdir(path) return path.startswith(self.sep)
def isabs(self, path): def isfile(self, path):
return path.startswith(self.sep) return path in self.files and self.files[path] is not None
def isfile(self, path): def isdir(self, path):
return path in self.files and self.files[path] is not None if path in self.files:
return False
if not path.endswith(self.sep):
path += self.sep
def isdir(self, path): # We need to use a copy of the keys here in order to avoid switching
if path in self.files: # to a different thread and potentially modifying the dict in
return False # mid-iteration.
if not path.endswith(self.sep): files = list(self.files.keys())[:]
path += self.sep return any(f.startswith(path) for f in files)
# We need to use a copy of the keys here in order to avoid switching def join(self, *comps):
# to a different thread and potentially modifying the dict in # TODO: Might want tests for this and/or a better comment about how
# mid-iteration. # it works.
files = list(self.files.keys())[:] return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps))
return any(f.startswith(path) for f in files)
def join(self, *comps): def glob(self, path):
# TODO: Might want tests for this and/or a better comment about how return fnmatch.filter(self.files.keys(), path)
# it works.
return re.sub(re.escape(os.path.sep), self.sep, os.path.join(*comps))
def glob(self, path): def open_for_reading(self, path):
return fnmatch.filter(self.files.keys(), path) return StringIO(self.read_binary_file(path))
def open_for_reading(self, path): def normpath(self, path):
return StringIO(self.read_binary_file(path)) # This is not a complete implementation of normpath. Only covers what we
# use in tests.
result = []
for part in path.split(self.sep):
if part == '..':
result.pop()
elif part == '.':
continue
else:
result.append(part)
return self.sep.join(result)
def normpath(self, path): def read_binary_file(self, path):
# This is not a complete implementation of normpath. Only covers what we # Intentionally raises KeyError if we don't recognize the path.
# use in tests. if self.files[path] is None:
result = [] _RaiseNotFound(path)
for part in path.split(self.sep): return self.files[path]
if part == '..':
result.pop()
elif part == '.':
continue
else:
result.append(part)
return self.sep.join(result)
def read_binary_file(self, path): def relpath(self, path, base):
# Intentionally raises KeyError if we don't recognize the path. # This implementation is wrong in many ways; assert to check them for
if self.files[path] is None: # now.
_RaiseNotFound(path) if not base.endswith(self.sep):
return self.files[path] base += self.sep
assert path.startswith(base)
return path[len(base):]
def relpath(self, path, base): def split(self, path):
# This implementation is wrong in many ways; assert to check them for now. return path.rsplit(self.sep, 1)
if not base.endswith(self.sep):
base += self.sep
assert path.startswith(base)
return path[len(base):]
def split(self, path):
return path.rsplit(self.sep, 1)

View File

@@ -17,108 +17,107 @@ import unittest
import gclient_utils import gclient_utils
DEFAULT_BRANCH = 'main' DEFAULT_BRANCH = 'main'
def git_hash_data(data, typ='blob'): def git_hash_data(data, typ='blob'):
"""Calculate the git-style SHA1 for some data. """Calculate the git-style SHA1 for some data.
Only supports 'blob' type data at the moment. Only supports 'blob' type data at the moment.
""" """
assert typ == 'blob', 'Only support blobs for now' assert typ == 'blob', 'Only support blobs for now'
return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest() return hashlib.sha1(b'blob %d\0%s' % (len(data), data)).hexdigest()
class OrderedSet(collections.MutableSet): class OrderedSet(collections.MutableSet):
# from http://code.activestate.com/recipes/576694/ # from http://code.activestate.com/recipes/576694/
def __init__(self, iterable=None): def __init__(self, iterable=None):
self.end = end = [] self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list end += [None, end, end] # sentinel node for doubly linked list
self.data = {} # key --> [key, prev, next] self.data = {} # key --> [key, prev, next]
if iterable is not None: if iterable is not None:
self |= iterable self |= iterable
def __contains__(self, key): def __contains__(self, key):
return key in self.data return key in self.data
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, OrderedSet): if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other) return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other) return set(self) == set(other)
def __ne__(self, other): def __ne__(self, other):
if isinstance(other, OrderedSet): if isinstance(other, OrderedSet):
return len(self) != len(other) or list(self) != list(other) return len(self) != len(other) or list(self) != list(other)
return set(self) != set(other) return set(self) != set(other)
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def __iter__(self): def __iter__(self):
end = self.end end = self.end
curr = end[2] curr = end[2]
while curr is not end: while curr is not end:
yield curr[0] yield curr[0]
curr = curr[2] curr = curr[2]
def __repr__(self): def __repr__(self):
if not self: if not self:
return '%s()' % (self.__class__.__name__,) return '%s()' % (self.__class__.__name__, )
return '%s(%r)' % (self.__class__.__name__, list(self)) return '%s(%r)' % (self.__class__.__name__, list(self))
def __reversed__(self): def __reversed__(self):
end = self.end end = self.end
curr = end[1] curr = end[1]
while curr is not end: while curr is not end:
yield curr[0] yield curr[0]
curr = curr[1] curr = curr[1]
def add(self, key): def add(self, key):
if key not in self.data: if key not in self.data:
end = self.end end = self.end
curr = end[1] curr = end[1]
curr[2] = end[1] = self.data[key] = [key, curr, end] curr[2] = end[1] = self.data[key] = [key, curr, end]
def difference_update(self, *others): def difference_update(self, *others):
for other in others: for other in others:
for i in other: for i in other:
self.discard(i) self.discard(i)
def discard(self, key): def discard(self, key):
if key in self.data: if key in self.data:
key, prev, nxt = self.data.pop(key) key, prev, nxt = self.data.pop(key)
prev[2] = nxt prev[2] = nxt
nxt[1] = prev nxt[1] = prev
def pop(self, last=True): # pylint: disable=arguments-differ def pop(self, last=True): # pylint: disable=arguments-differ
if not self: if not self:
raise KeyError('set is empty') raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0] key = self.end[1][0] if last else self.end[2][0]
self.discard(key) self.discard(key)
return key return key
class UTC(datetime.tzinfo): class UTC(datetime.tzinfo):
"""UTC time zone. """UTC time zone.
from https://docs.python.org/2/library/datetime.html#tzinfo-objects from https://docs.python.org/2/library/datetime.html#tzinfo-objects
""" """
def utcoffset(self, dt): def utcoffset(self, dt):
return datetime.timedelta(0) return datetime.timedelta(0)
def tzname(self, dt): def tzname(self, dt):
return "UTC" return "UTC"
def dst(self, dt): def dst(self, dt):
return datetime.timedelta(0) return datetime.timedelta(0)
UTC = UTC() UTC = UTC()
class GitRepoSchema(object): class GitRepoSchema(object):
"""A declarative git testing repo. """A declarative git testing repo.
Pass a schema to __init__ in the form of: Pass a schema to __init__ in the form of:
A B C D A B C D
@@ -141,11 +140,10 @@ class GitRepoSchema(object):
in the schema) get earlier timestamps. Stamps start at the Unix Epoch, and in the schema) get earlier timestamps. Stamps start at the Unix Epoch, and
increment by 1 day each. increment by 1 day each.
""" """
COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root') COMMIT = collections.namedtuple('COMMIT', 'name parents is_branch is_root')
def __init__(self, repo_schema='', def __init__(self, repo_schema='', content_fn=lambda v: {v: {'data': v}}):
content_fn=lambda v: {v: {'data': v}}): """Builds a new GitRepoSchema.
"""Builds a new GitRepoSchema.
Args: Args:
repo_schema (str) - Initial schema for this repo. See class docstring for repo_schema (str) - Initial schema for this repo. See class docstring for
@@ -156,88 +154,88 @@ class GitRepoSchema(object):
commit_name). See the docstring on the GitRepo class for the format of commit_name). See the docstring on the GitRepo class for the format of
the data returned by this function. the data returned by this function.
""" """
self.main = None self.main = None
self.par_map = {} self.par_map = {}
self.data_cache = {} self.data_cache = {}
self.content_fn = content_fn self.content_fn = content_fn
self.add_commits(repo_schema) self.add_commits(repo_schema)
def walk(self): def walk(self):
"""(Generator) Walks the repo schema from roots to tips. """(Generator) Walks the repo schema from roots to tips.
Generates GitRepoSchema.COMMIT objects for each commit. Generates GitRepoSchema.COMMIT objects for each commit.
Throws an AssertionError if it detects a cycle. Throws an AssertionError if it detects a cycle.
""" """
is_root = True is_root = True
par_map = copy.deepcopy(self.par_map) par_map = copy.deepcopy(self.par_map)
while par_map: while par_map:
empty_keys = set(k for k, v in par_map.items() if not v) empty_keys = set(k for k, v in par_map.items() if not v)
assert empty_keys, 'Cycle detected! %s' % par_map assert empty_keys, 'Cycle detected! %s' % par_map
for k in sorted(empty_keys): for k in sorted(empty_keys):
yield self.COMMIT(k, self.par_map[k], yield self.COMMIT(
not any(k in v for v in self.par_map.values()), k, self.par_map[k],
is_root) not any(k in v for v in self.par_map.values()), is_root)
del par_map[k] del par_map[k]
for v in par_map.values(): for v in par_map.values():
v.difference_update(empty_keys) v.difference_update(empty_keys)
is_root = False is_root = False
def add_partial(self, commit, parent=None): def add_partial(self, commit, parent=None):
if commit not in self.par_map: if commit not in self.par_map:
self.par_map[commit] = OrderedSet() self.par_map[commit] = OrderedSet()
if parent is not None: if parent is not None:
self.par_map[commit].add(parent) self.par_map[commit].add(parent)
def add_commits(self, schema): def add_commits(self, schema):
"""Adds more commits from a schema into the existing Schema. """Adds more commits from a schema into the existing Schema.
Args: Args:
schema (str) - See class docstring for info on schema format. schema (str) - See class docstring for info on schema format.
Throws an AssertionError if it detects a cycle. Throws an AssertionError if it detects a cycle.
""" """
for commits in (l.split() for l in schema.splitlines() if l.strip()): for commits in (l.split() for l in schema.splitlines() if l.strip()):
parent = None parent = None
for commit in commits: for commit in commits:
self.add_partial(commit, parent) self.add_partial(commit, parent)
parent = commit parent = commit
if parent and not self.main: if parent and not self.main:
self.main = parent self.main = parent
for _ in self.walk(): # This will throw if there are any cycles. for _ in self.walk(): # This will throw if there are any cycles.
pass pass
def reify(self): def reify(self):
"""Returns a real GitRepo for this GitRepoSchema""" """Returns a real GitRepo for this GitRepoSchema"""
return GitRepo(self) return GitRepo(self)
def data_for(self, commit): def data_for(self, commit):
"""Obtains the data for |commit|. """Obtains the data for |commit|.
See the docstring on the GitRepo class for the format of the returned data. See the docstring on the GitRepo class for the format of the returned data.
Caches the result on this GitRepoSchema instance. Caches the result on this GitRepoSchema instance.
""" """
if commit not in self.data_cache: if commit not in self.data_cache:
self.data_cache[commit] = self.content_fn(commit) self.data_cache[commit] = self.content_fn(commit)
return self.data_cache[commit] return self.data_cache[commit]
def simple_graph(self): def simple_graph(self):
"""Returns a dictionary of {commit_subject: {parent commit_subjects}} """Returns a dictionary of {commit_subject: {parent commit_subjects}}
This allows you to get a very simple connection graph over the whole repo This allows you to get a very simple connection graph over the whole repo
for comparison purposes. Only commit subjects (not ids, not content/data) for comparison purposes. Only commit subjects (not ids, not content/data)
are considered are considered
""" """
ret = {} ret = {}
for commit in self.walk(): for commit in self.walk():
ret.setdefault(commit.name, set()).update(commit.parents) ret.setdefault(commit.name, set()).update(commit.parents)
return ret return ret
class GitRepo(object): class GitRepo(object):
"""Creates a real git repo for a GitRepoSchema. """Creates a real git repo for a GitRepoSchema.
Obtains schema and content information from the GitRepoSchema. Obtains schema and content information from the GitRepoSchema.
@@ -260,26 +258,26 @@ class GitRepo(object):
For file content, if 'data' is None, then this commit will `git rm` that file. For file content, if 'data' is None, then this commit will `git rm` that file.
""" """
BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo') BASE_TEMP_DIR = tempfile.mkdtemp(suffix='base', prefix='git_repo')
atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR) atexit.register(gclient_utils.rmtree, BASE_TEMP_DIR)
# Singleton objects to specify specific data in a commit dictionary. # Singleton objects to specify specific data in a commit dictionary.
AUTHOR_NAME = object() AUTHOR_NAME = object()
AUTHOR_EMAIL = object() AUTHOR_EMAIL = object()
AUTHOR_DATE = object() AUTHOR_DATE = object()
COMMITTER_NAME = object() COMMITTER_NAME = object()
COMMITTER_EMAIL = object() COMMITTER_EMAIL = object()
COMMITTER_DATE = object() COMMITTER_DATE = object()
DEFAULT_AUTHOR_NAME = 'Author McAuthorly' DEFAULT_AUTHOR_NAME = 'Author McAuthorly'
DEFAULT_AUTHOR_EMAIL = 'author@example.com' DEFAULT_AUTHOR_EMAIL = 'author@example.com'
DEFAULT_COMMITTER_NAME = 'Charles Committish' DEFAULT_COMMITTER_NAME = 'Charles Committish'
DEFAULT_COMMITTER_EMAIL = 'commitish@example.com' DEFAULT_COMMITTER_EMAIL = 'commitish@example.com'
COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout') COMMAND_OUTPUT = collections.namedtuple('COMMAND_OUTPUT', 'retcode stdout')
def __init__(self, schema): def __init__(self, schema):
"""Makes new GitRepo. """Makes new GitRepo.
Automatically creates a temp folder under GitRepo.BASE_TEMP_DIR. It's Automatically creates a temp folder under GitRepo.BASE_TEMP_DIR. It's
recommended that you clean this repo up by calling nuke() on it, but if not, recommended that you clean this repo up by calling nuke() on it, but if not,
@@ -289,194 +287,198 @@ class GitRepo(object):
Args: Args:
schema - An instance of GitRepoSchema schema - An instance of GitRepoSchema
""" """
self.repo_path = os.path.realpath(tempfile.mkdtemp(dir=self.BASE_TEMP_DIR)) self.repo_path = os.path.realpath(
self.commit_map = {} tempfile.mkdtemp(dir=self.BASE_TEMP_DIR))
self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC) self.commit_map = {}
self._date = datetime.datetime(1970, 1, 1, tzinfo=UTC)
self.to_schema_refs = ['--branches'] self.to_schema_refs = ['--branches']
self.git('init', '-b', DEFAULT_BRANCH) self.git('init', '-b', DEFAULT_BRANCH)
self.git('config', 'user.name', 'testcase') self.git('config', 'user.name', 'testcase')
self.git('config', 'user.email', 'testcase@example.com') self.git('config', 'user.email', 'testcase@example.com')
for commit in schema.walk(): for commit in schema.walk():
self._add_schema_commit(commit, schema.data_for(commit.name)) self._add_schema_commit(commit, schema.data_for(commit.name))
self.last_commit = self[commit.name] self.last_commit = self[commit.name]
if schema.main: if schema.main:
self.git('update-ref', 'refs/heads/main', self[schema.main]) self.git('update-ref', 'refs/heads/main', self[schema.main])
def __getitem__(self, commit_name): def __getitem__(self, commit_name):
"""Gets the hash of a commit by its schema name. """Gets the hash of a commit by its schema name.
>>> r = GitRepo(GitRepoSchema('A B C')) >>> r = GitRepo(GitRepoSchema('A B C'))
>>> r['B'] >>> r['B']
'7381febe1da03b09da47f009963ab7998a974935' '7381febe1da03b09da47f009963ab7998a974935'
""" """
return self.commit_map[commit_name] return self.commit_map[commit_name]
def _add_schema_commit(self, commit, commit_data): def _add_schema_commit(self, commit, commit_data):
commit_data = commit_data or {} commit_data = commit_data or {}
if commit.parents: if commit.parents:
parents = list(commit.parents) parents = list(commit.parents)
self.git('checkout', '--detach', '-q', self[parents[0]]) self.git('checkout', '--detach', '-q', self[parents[0]])
if len(parents) > 1: if len(parents) > 1:
self.git('merge', '--no-commit', '-q', *[self[x] for x in parents[1:]]) self.git('merge', '--no-commit', '-q',
else: *[self[x] for x in parents[1:]])
self.git('checkout', '--orphan', 'root_%s' % commit.name)
self.git('rm', '-rf', '.')
env = self.get_git_commit_env(commit_data)
for fname, file_data in commit_data.items():
# If it isn't a string, it's one of the special keys.
if not isinstance(fname, str):
continue
deleted = False
if 'data' in file_data:
data = file_data.get('data')
if data is None:
deleted = True
self.git('rm', fname)
else: else:
path = os.path.join(self.repo_path, fname) self.git('checkout', '--orphan', 'root_%s' % commit.name)
pardir = os.path.dirname(path) self.git('rm', '-rf', '.')
if not os.path.exists(pardir):
os.makedirs(pardir)
with open(path, 'wb') as f:
f.write(data)
mode = file_data.get('mode') env = self.get_git_commit_env(commit_data)
if mode and not deleted:
os.chmod(path, mode)
self.git('add', fname) for fname, file_data in commit_data.items():
# If it isn't a string, it's one of the special keys.
if not isinstance(fname, str):
continue
rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env) deleted = False
assert rslt.retcode == 0, 'Failed to commit %s' % str(commit) if 'data' in file_data:
self.commit_map[commit.name] = self.git('rev-parse', 'HEAD').stdout.strip() data = file_data.get('data')
self.git('tag', 'tag_%s' % commit.name, self[commit.name]) if data is None:
if commit.is_branch: deleted = True
self.git('branch', '-f', 'branch_%s' % commit.name, self[commit.name]) self.git('rm', fname)
else:
path = os.path.join(self.repo_path, fname)
pardir = os.path.dirname(path)
if not os.path.exists(pardir):
os.makedirs(pardir)
with open(path, 'wb') as f:
f.write(data)
def get_git_commit_env(self, commit_data=None): mode = file_data.get('mode')
commit_data = commit_data or {} if mode and not deleted:
env = os.environ.copy() os.chmod(path, mode)
for prefix in ('AUTHOR', 'COMMITTER'):
for suffix in ('NAME', 'EMAIL', 'DATE'):
singleton = '%s_%s' % (prefix, suffix)
key = getattr(self, singleton)
if key in commit_data:
val = commit_data[key]
elif suffix == 'DATE':
val = self._date
self._date += datetime.timedelta(days=1)
else:
val = getattr(self, 'DEFAULT_%s' % singleton)
if not isinstance(val, str) and not isinstance(val, bytes):
val = str(val)
env['GIT_%s' % singleton] = val
return env
def git(self, *args, **kwargs): self.git('add', fname)
"""Runs a git command specified by |args| in this repo."""
assert self.repo_path is not None
try:
with open(os.devnull, 'wb') as devnull:
shell = sys.platform == 'win32'
output = subprocess.check_output(
('git', ) + args,
shell=shell,
cwd=self.repo_path,
stderr=devnull,
**kwargs)
output = output.decode('utf-8')
return self.COMMAND_OUTPUT(0, output)
except subprocess.CalledProcessError as e:
return self.COMMAND_OUTPUT(e.returncode, e.output)
def show_commit(self, commit_name, format_string): rslt = self.git('commit', '--allow-empty', '-m', commit.name, env=env)
"""Shows a commit (by its schema name) with a given format string.""" assert rslt.retcode == 0, 'Failed to commit %s' % str(commit)
return self.git('show', '-q', '--pretty=format:%s' % format_string, self.commit_map[commit.name] = self.git('rev-parse',
self[commit_name]).stdout 'HEAD').stdout.strip()
self.git('tag', 'tag_%s' % commit.name, self[commit.name])
if commit.is_branch:
self.git('branch', '-f', 'branch_%s' % commit.name,
self[commit.name])
def git_commit(self, message): def get_git_commit_env(self, commit_data=None):
return self.git('commit', '-am', message, env=self.get_git_commit_env()) commit_data = commit_data or {}
env = os.environ.copy()
for prefix in ('AUTHOR', 'COMMITTER'):
for suffix in ('NAME', 'EMAIL', 'DATE'):
singleton = '%s_%s' % (prefix, suffix)
key = getattr(self, singleton)
if key in commit_data:
val = commit_data[key]
elif suffix == 'DATE':
val = self._date
self._date += datetime.timedelta(days=1)
else:
val = getattr(self, 'DEFAULT_%s' % singleton)
if not isinstance(val, str) and not isinstance(val, bytes):
val = str(val)
env['GIT_%s' % singleton] = val
return env
def nuke(self): def git(self, *args, **kwargs):
"""Obliterates the git repo on disk. """Runs a git command specified by |args| in this repo."""
assert self.repo_path is not None
try:
with open(os.devnull, 'wb') as devnull:
shell = sys.platform == 'win32'
output = subprocess.check_output(('git', ) + args,
shell=shell,
cwd=self.repo_path,
stderr=devnull,
**kwargs)
output = output.decode('utf-8')
return self.COMMAND_OUTPUT(0, output)
except subprocess.CalledProcessError as e:
return self.COMMAND_OUTPUT(e.returncode, e.output)
def show_commit(self, commit_name, format_string):
"""Shows a commit (by its schema name) with a given format string."""
return self.git('show', '-q', '--pretty=format:%s' % format_string,
self[commit_name]).stdout
def git_commit(self, message):
return self.git('commit', '-am', message, env=self.get_git_commit_env())
def nuke(self):
"""Obliterates the git repo on disk.
Causes this GitRepo to be unusable. Causes this GitRepo to be unusable.
""" """
gclient_utils.rmtree(self.repo_path) gclient_utils.rmtree(self.repo_path)
self.repo_path = None self.repo_path = None
def run(self, fn, *args, **kwargs): def run(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd set to """Run a python function with the given args and kwargs with the cwd
the git repo.""" set to the git repo."""
assert self.repo_path is not None assert self.repo_path is not None
curdir = os.getcwd() curdir = os.getcwd()
try: try:
os.chdir(self.repo_path) os.chdir(self.repo_path)
return fn(*args, **kwargs) return fn(*args, **kwargs)
finally: finally:
os.chdir(curdir) os.chdir(curdir)
def capture_stdio(self, fn, *args, **kwargs): def capture_stdio(self, fn, *args, **kwargs):
"""Run a python function with the given args and kwargs with the cwd set to """Run a python function with the given args and kwargs with the cwd set
the git repo. to the git repo.
Returns the (stdout, stderr) of whatever ran, instead of the what |fn| Returns the (stdout, stderr) of whatever ran, instead of the what |fn|
returned. returned.
""" """
stdout = sys.stdout stdout = sys.stdout
stderr = sys.stderr stderr = sys.stderr
try: try:
with tempfile.TemporaryFile('w+') as out: with tempfile.TemporaryFile('w+') as out:
with tempfile.TemporaryFile('w+') as err: with tempfile.TemporaryFile('w+') as err:
sys.stdout = out sys.stdout = out
sys.stderr = err sys.stderr = err
try: try:
self.run(fn, *args, **kwargs) self.run(fn, *args, **kwargs)
except SystemExit: except SystemExit:
pass pass
out.seek(0) out.seek(0)
err.seek(0) err.seek(0)
return out.read(), err.read() return out.read(), err.read()
finally: finally:
sys.stdout = stdout sys.stdout = stdout
sys.stderr = stderr sys.stderr = stderr
def open(self, path, mode='rb'): def open(self, path, mode='rb'):
return open(os.path.join(self.repo_path, path), mode) return open(os.path.join(self.repo_path, path), mode)
def to_schema(self): def to_schema(self):
lines = self.git('rev-list', '--parents', '--reverse', '--topo-order', lines = self.git('rev-list', '--parents', '--reverse', '--topo-order',
'--format=%s', *self.to_schema_refs).stdout.splitlines() '--format=%s',
hash_to_msg = {} *self.to_schema_refs).stdout.splitlines()
ret = GitRepoSchema() hash_to_msg = {}
current = None ret = GitRepoSchema()
parents = []
for line in lines:
if line.startswith('commit'):
assert current is None
tokens = line.split()
current, parents = tokens[1], tokens[2:]
assert all(p in hash_to_msg for p in parents)
else:
assert current is not None
hash_to_msg[current] = line
ret.add_partial(line)
for parent in parents:
ret.add_partial(line, hash_to_msg[parent])
current = None current = None
parents = [] parents = []
assert current is None for line in lines:
return ret if line.startswith('commit'):
assert current is None
tokens = line.split()
current, parents = tokens[1], tokens[2:]
assert all(p in hash_to_msg for p in parents)
else:
assert current is not None
hash_to_msg[current] = line
ret.add_partial(line)
for parent in parents:
ret.add_partial(line, hash_to_msg[parent])
current = None
parents = []
assert current is None
return ret
class GitRepoSchemaTestBase(unittest.TestCase): class GitRepoSchemaTestBase(unittest.TestCase):
"""A TestCase with a built-in GitRepoSchema. """A TestCase with a built-in GitRepoSchema.
Expects a class variable REPO_SCHEMA to be a GitRepoSchema string in the form Expects a class variable REPO_SCHEMA to be a GitRepoSchema string in the form
described by that class. described by that class.
@@ -487,61 +489,62 @@ class GitRepoSchemaTestBase(unittest.TestCase):
You probably will end up using either GitRepoReadOnlyTestBase or You probably will end up using either GitRepoReadOnlyTestBase or
GitRepoReadWriteTestBase for real tests. GitRepoReadWriteTestBase for real tests.
""" """
REPO_SCHEMA = None REPO_SCHEMA = None
@classmethod @classmethod
def getRepoContent(cls, commit): def getRepoContent(cls, commit):
commit = 'COMMIT_%s' % commit commit = 'COMMIT_%s' % commit
return getattr(cls, commit, None) return getattr(cls, commit, None)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(GitRepoSchemaTestBase, cls).setUpClass() super(GitRepoSchemaTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None assert cls.REPO_SCHEMA is not None
cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent) cls.r_schema = GitRepoSchema(cls.REPO_SCHEMA, cls.getRepoContent)
class GitRepoReadOnlyTestBase(GitRepoSchemaTestBase): class GitRepoReadOnlyTestBase(GitRepoSchemaTestBase):
"""Injects a GitRepo object given the schema and content from """Injects a GitRepo object given the schema and content from
GitRepoSchemaTestBase into TestCase classes which subclass this. GitRepoSchemaTestBase into TestCase classes which subclass this.
This GitRepo will appear as self.repo, and will be deleted and recreated once This GitRepo will appear as self.repo, and will be deleted and recreated once
for the duration of all the tests in the subclass. for the duration of all the tests in the subclass.
""" """
REPO_SCHEMA = None REPO_SCHEMA = None
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(GitRepoReadOnlyTestBase, cls).setUpClass() super(GitRepoReadOnlyTestBase, cls).setUpClass()
assert cls.REPO_SCHEMA is not None assert cls.REPO_SCHEMA is not None
cls.repo = cls.r_schema.reify() cls.repo = cls.r_schema.reify()
def setUp(self): def setUp(self):
self.repo.git('checkout', '-f', self.repo.last_commit) self.repo.git('checkout', '-f', self.repo.last_commit)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
cls.repo.nuke() cls.repo.nuke()
super(GitRepoReadOnlyTestBase, cls).tearDownClass() super(GitRepoReadOnlyTestBase, cls).tearDownClass()
class GitRepoReadWriteTestBase(GitRepoSchemaTestBase): class GitRepoReadWriteTestBase(GitRepoSchemaTestBase):
"""Injects a GitRepo object given the schema and content from """Injects a GitRepo object given the schema and content from
GitRepoSchemaTestBase into TestCase classes which subclass this. GitRepoSchemaTestBase into TestCase classes which subclass this.
This GitRepo will appear as self.repo, and will be deleted and recreated for This GitRepo will appear as self.repo, and will be deleted and recreated for
each test function in the subclass. each test function in the subclass.
""" """
REPO_SCHEMA = None REPO_SCHEMA = None
def setUp(self): def setUp(self):
super(GitRepoReadWriteTestBase, self).setUp() super(GitRepoReadWriteTestBase, self).setUp()
self.repo = self.r_schema.reify() self.repo = self.r_schema.reify()
def tearDown(self): def tearDown(self):
self.repo.nuke() self.repo.nuke()
super(GitRepoReadWriteTestBase, self).tearDown() super(GitRepoReadWriteTestBase, self).tearDown()
def assertSchema(self, schema_string): def assertSchema(self, schema_string):
self.assertEqual(GitRepoSchema(schema_string).simple_graph(), self.assertEqual(
self.repo.to_schema().simple_graph()) GitRepoSchema(schema_string).simple_graph(),
self.repo.to_schema().simple_graph())

View File

@@ -15,10 +15,12 @@ from presubmit_canned_checks import _ReportErrorFileAndLine
class MockCannedChecks(object): class MockCannedChecks(object):
def _FindNewViolationsOfRule(self, callable_rule, input_api, def _FindNewViolationsOfRule(self,
source_file_filter=None, callable_rule,
error_formatter=_ReportErrorFileAndLine): input_api,
"""Find all newly introduced violations of a per-line rule (a callable). source_file_filter=None,
error_formatter=_ReportErrorFileAndLine):
"""Find all newly introduced violations of a per-line rule (a callable).
Arguments: Arguments:
callable_rule: a callable taking a file extension and line of input and callable_rule: a callable taking a file extension and line of input and
@@ -32,232 +34,246 @@ class MockCannedChecks(object):
Returns: Returns:
A list of the newly-introduced violations reported by the rule. A list of the newly-introduced violations reported by the rule.
""" """
errors = [] errors = []
for f in input_api.AffectedFiles(include_deletes=False, for f in input_api.AffectedFiles(include_deletes=False,
file_filter=source_file_filter): file_filter=source_file_filter):
# For speed, we do two passes, checking first the full file. Shelling out # For speed, we do two passes, checking first the full file.
# to the SCM to determine the changed region can be quite expensive on # Shelling out to the SCM to determine the changed region can be
# Win32. Assuming that most files will be kept problem-free, we can # quite expensive on Win32. Assuming that most files will be kept
# skip the SCM operations most of the time. # problem-free, we can skip the SCM operations most of the time.
extension = str(f.LocalPath()).rsplit('.', 1)[-1] extension = str(f.LocalPath()).rsplit('.', 1)[-1]
if all(callable_rule(extension, line) for line in f.NewContents()): if all(callable_rule(extension, line) for line in f.NewContents()):
continue # No violation found in full text: can skip considering diff. # No violation found in full text: can skip considering diff.
continue
for line_num, line in f.ChangedContents(): for line_num, line in f.ChangedContents():
if not callable_rule(extension, line): if not callable_rule(extension, line):
errors.append(error_formatter(f.LocalPath(), line_num, line)) errors.append(error_formatter(f.LocalPath(), line_num,
line))
return errors return errors
class MockInputApi(object): class MockInputApi(object):
"""Mock class for the InputApi class. """Mock class for the InputApi class.
This class can be used for unittests for presubmit by initializing the files This class can be used for unittests for presubmit by initializing the files
attribute as the list of changed files. attribute as the list of changed files.
""" """
DEFAULT_FILES_TO_SKIP = () DEFAULT_FILES_TO_SKIP = ()
def __init__(self): def __init__(self):
self.canned_checks = MockCannedChecks() self.canned_checks = MockCannedChecks()
self.fnmatch = fnmatch self.fnmatch = fnmatch
self.json = json self.json = json
self.re = re self.re = re
self.os_path = os.path self.os_path = os.path
self.platform = sys.platform self.platform = sys.platform
self.python_executable = sys.executable self.python_executable = sys.executable
self.platform = sys.platform self.platform = sys.platform
self.subprocess = subprocess self.subprocess = subprocess
self.sys = sys self.sys = sys
self.files = [] self.files = []
self.is_committing = False self.is_committing = False
self.no_diffs = False self.no_diffs = False
self.change = MockChange([]) self.change = MockChange([])
self.presubmit_local_path = os.path.dirname(__file__) self.presubmit_local_path = os.path.dirname(__file__)
self.logging = logging.getLogger('PRESUBMIT') self.logging = logging.getLogger('PRESUBMIT')
def CreateMockFileInPath(self, f_list): def CreateMockFileInPath(self, f_list):
self.os_path.exists = lambda x: x in f_list self.os_path.exists = lambda x: x in f_list
def AffectedFiles(self, file_filter=None, include_deletes=True): def AffectedFiles(self, file_filter=None, include_deletes=True):
for file in self.files: # pylint: disable=redefined-builtin for file in self.files: # pylint: disable=redefined-builtin
if file_filter and not file_filter(file): if file_filter and not file_filter(file):
continue continue
if not include_deletes and file.Action() == 'D': if not include_deletes and file.Action() == 'D':
continue continue
yield file yield file
def AffectedSourceFiles(self, file_filter=None): def AffectedSourceFiles(self, file_filter=None):
return self.AffectedFiles(file_filter=file_filter) return self.AffectedFiles(file_filter=file_filter)
def FilterSourceFile(self, file, # pylint: disable=redefined-builtin def FilterSourceFile(
files_to_check=(), files_to_skip=()): self,
local_path = file.LocalPath() file, # pylint: disable=redefined-builtin
found_in_files_to_check = not files_to_check files_to_check=(),
if files_to_check: files_to_skip=()):
if isinstance(files_to_check, str): local_path = file.LocalPath()
raise TypeError('files_to_check should be an iterable of strings') found_in_files_to_check = not files_to_check
for pattern in files_to_check: if files_to_check:
compiled_pattern = re.compile(pattern) if isinstance(files_to_check, str):
if compiled_pattern.search(local_path): raise TypeError(
found_in_files_to_check = True 'files_to_check should be an iterable of strings')
break for pattern in files_to_check:
if files_to_skip: compiled_pattern = re.compile(pattern)
if isinstance(files_to_skip, str): if compiled_pattern.search(local_path):
raise TypeError('files_to_skip should be an iterable of strings') found_in_files_to_check = True
for pattern in files_to_skip: break
compiled_pattern = re.compile(pattern) if files_to_skip:
if compiled_pattern.search(local_path): if isinstance(files_to_skip, str):
return False raise TypeError(
return found_in_files_to_check 'files_to_skip should be an iterable of strings')
for pattern in files_to_skip:
compiled_pattern = re.compile(pattern)
if compiled_pattern.search(local_path):
return False
return found_in_files_to_check
def LocalPaths(self): def LocalPaths(self):
return [file.LocalPath() for file in self.files] # pylint: disable=redefined-builtin return [file.LocalPath() for file in self.files] # pylint: disable=redefined-builtin
def PresubmitLocalPath(self): def PresubmitLocalPath(self):
return self.presubmit_local_path return self.presubmit_local_path
def ReadFile(self, filename, mode='rU'): def ReadFile(self, filename, mode='rU'):
if hasattr(filename, 'AbsoluteLocalPath'): if hasattr(filename, 'AbsoluteLocalPath'):
filename = filename.AbsoluteLocalPath() filename = filename.AbsoluteLocalPath()
for file_ in self.files: for file_ in self.files:
if file_.LocalPath() == filename: if file_.LocalPath() == filename:
return '\n'.join(file_.NewContents()) return '\n'.join(file_.NewContents())
# Otherwise, file is not in our mock API. # Otherwise, file is not in our mock API.
raise IOError("No such file or directory: '%s'" % filename) raise IOError("No such file or directory: '%s'" % filename)
class MockOutputApi(object): class MockOutputApi(object):
"""Mock class for the OutputApi class. """Mock class for the OutputApi class.
An instance of this class can be passed to presubmit unittests for outputing An instance of this class can be passed to presubmit unittests for outputing
various types of results. various types of results.
""" """
class PresubmitResult(object):
def __init__(self, message, items=None, long_text=''):
self.message = message
self.items = items
self.long_text = long_text
class PresubmitResult(object): def __repr__(self):
def __init__(self, message, items=None, long_text=''): return self.message
self.message = message
self.items = items
self.long_text = long_text
def __repr__(self): class PresubmitError(PresubmitResult):
return self.message def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'error'
class PresubmitError(PresubmitResult): class PresubmitPromptWarning(PresubmitResult):
def __init__(self, message, items=None, long_text=''): def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) MockOutputApi.PresubmitResult.__init__(self, message, items,
self.type = 'error' long_text)
self.type = 'warning'
class PresubmitPromptWarning(PresubmitResult): class PresubmitNotifyResult(PresubmitResult):
def __init__(self, message, items=None, long_text=''): def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) MockOutputApi.PresubmitResult.__init__(self, message, items,
self.type = 'warning' long_text)
self.type = 'notify'
class PresubmitNotifyResult(PresubmitResult): class PresubmitPromptOrNotify(PresubmitResult):
def __init__(self, message, items=None, long_text=''): def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text) MockOutputApi.PresubmitResult.__init__(self, message, items,
self.type = 'notify' long_text)
self.type = 'promptOrNotify'
class PresubmitPromptOrNotify(PresubmitResult): def __init__(self):
def __init__(self, message, items=None, long_text=''): self.more_cc = []
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'promptOrNotify'
def __init__(self): def AppendCC(self, more_cc):
self.more_cc = [] self.more_cc.extend(more_cc)
def AppendCC(self, more_cc):
self.more_cc.extend(more_cc)
class MockFile(object): class MockFile(object):
"""Mock class for the File class. """Mock class for the File class.
This class can be used to form the mock list of changed files in This class can be used to form the mock list of changed files in
MockInputApi for presubmit unittests. MockInputApi for presubmit unittests.
""" """
def __init__(self,
local_path,
new_contents,
old_contents=None,
action='A',
scm_diff=None):
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l)
for i, l in enumerate(new_contents)]
self._action = action
if scm_diff:
self._scm_diff = scm_diff
else:
self._scm_diff = ("--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" %
(local_path, len(new_contents)))
for l in new_contents:
self._scm_diff += "+%s\n" % l
self._old_contents = old_contents
def __init__(self, local_path, new_contents, old_contents=None, action='A', def Action(self):
scm_diff=None): return self._action
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l) for i, l in enumerate(new_contents)]
self._action = action
if scm_diff:
self._scm_diff = scm_diff
else:
self._scm_diff = (
"--- /dev/null\n+++ %s\n@@ -0,0 +1,%d @@\n" %
(local_path, len(new_contents)))
for l in new_contents:
self._scm_diff += "+%s\n" % l
self._old_contents = old_contents
def Action(self): def ChangedContents(self):
return self._action return self._changed_contents
def ChangedContents(self): def NewContents(self):
return self._changed_contents return self._new_contents
def NewContents(self): def LocalPath(self):
return self._new_contents return self._local_path
def LocalPath(self): def AbsoluteLocalPath(self):
return self._local_path return self._local_path
def AbsoluteLocalPath(self): def GenerateScmDiff(self):
return self._local_path return self._scm_diff
def GenerateScmDiff(self): def OldContents(self):
return self._scm_diff return self._old_contents
def OldContents(self): def rfind(self, p):
return self._old_contents """os.path.basename is used on MockFile so we need an rfind method."""
return self._local_path.rfind(p)
def rfind(self, p): def __getitem__(self, i):
"""os.path.basename is called on MockFile so we need an rfind method.""" """os.path.basename is used on MockFile so we need a get method."""
return self._local_path.rfind(p) return self._local_path[i]
def __getitem__(self, i): def __len__(self):
"""os.path.basename is called on MockFile so we need a get method.""" """os.path.basename is used on MockFile so we need a len method."""
return self._local_path[i] return len(self._local_path)
def __len__(self): def replace(self, altsep, sep):
"""os.path.basename is called on MockFile so we need a len method.""" """os.path.basename is used on MockFile so we need a replace method."""
return len(self._local_path) return self._local_path.replace(altsep, sep)
def replace(self, altsep, sep):
"""os.path.basename is called on MockFile so we need a replace method."""
return self._local_path.replace(altsep, sep)
class MockAffectedFile(MockFile): class MockAffectedFile(MockFile):
def AbsoluteLocalPath(self): def AbsoluteLocalPath(self):
return self._local_path return self._local_path
class MockChange(object): class MockChange(object):
"""Mock class for Change class. """Mock class for Change class.
This class can be used in presubmit unittests to mock the query of the This class can be used in presubmit unittests to mock the query of the
current change. current change.
""" """
def __init__(self, changed_files, description=''):
self._changed_files = changed_files
self.footers = defaultdict(list)
self._description = description
def __init__(self, changed_files, description=''): def LocalPaths(self):
self._changed_files = changed_files return self._changed_files
self.footers = defaultdict(list)
self._description = description
def LocalPaths(self): def AffectedFiles(self,
return self._changed_files include_dirs=False,
include_deletes=True,
file_filter=None):
return self._changed_files
def AffectedFiles(self, include_dirs=False, include_deletes=True, def GitFootersFromDescription(self):
file_filter=None): return self.footers
return self._changed_files
def GitFootersFromDescription(self): def DescriptionText(self):
return self.footers return self._description
def DescriptionText(self):
return self._description

View File

@@ -2,7 +2,6 @@
# Copyright (c) 2019 The Chromium Authors. All rights reserved. # Copyright (c) 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be # Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file. # found in the LICENSE file.
"""Script used to test subprocess2.""" """Script used to test subprocess2."""
import optparse import optparse
@@ -10,51 +9,52 @@ import os
import sys import sys
import time import time
if sys.platform == 'win32': if sys.platform == 'win32':
# Annoying, make sure the output is not translated on Windows. # Annoying, make sure the output is not translated on Windows.
# pylint: disable=no-member,import-error # pylint: disable=no-member,import-error
import msvcrt import msvcrt
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY) msvcrt.setmode(sys.stderr.fileno(), os.O_BINARY)
parser = optparse.OptionParser() parser = optparse.OptionParser()
parser.add_option( parser.add_option('--fail',
'--fail', dest='return_value',
dest='return_value', action='store_const',
action='store_const', default=0,
default=0, const=64)
const=64) parser.add_option('--crlf',
parser.add_option( action='store_const',
'--crlf', action='store_const', const='\r\n', dest='eol', default='\n') const='\r\n',
parser.add_option( dest='eol',
'--cr', action='store_const', const='\r', dest='eol') default='\n')
parser.add_option('--cr', action='store_const', const='\r', dest='eol')
parser.add_option('--stdout', action='store_true') parser.add_option('--stdout', action='store_true')
parser.add_option('--stderr', action='store_true') parser.add_option('--stderr', action='store_true')
parser.add_option('--read', action='store_true') parser.add_option('--read', action='store_true')
options, args = parser.parse_args() options, args = parser.parse_args()
if args: if args:
parser.error('Internal error') parser.error('Internal error')
def do(string): def do(string):
if options.stdout: if options.stdout:
sys.stdout.buffer.write(string.upper().encode('utf-8')) sys.stdout.buffer.write(string.upper().encode('utf-8'))
sys.stdout.buffer.write(options.eol.encode('utf-8')) sys.stdout.buffer.write(options.eol.encode('utf-8'))
if options.stderr: if options.stderr:
sys.stderr.buffer.write(string.lower().encode('utf-8')) sys.stderr.buffer.write(string.lower().encode('utf-8'))
sys.stderr.buffer.write(options.eol.encode('utf-8')) sys.stderr.buffer.write(options.eol.encode('utf-8'))
sys.stderr.flush() sys.stderr.flush()
do('A') do('A')
do('BB') do('BB')
do('CCC') do('CCC')
if options.read: if options.read:
assert options.return_value == 0 assert options.return_value == 0
try: try:
while sys.stdin.read(1): while sys.stdin.read(1):
options.return_value += 1 options.return_value += 1
except OSError: except OSError:
pass pass
sys.exit(options.return_value) sys.exit(options.return_value)

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2019 The Chromium Authors. All rights reserved. # Copyright (c) 2019 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be # Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file. # found in the LICENSE file.
"""Simplify unit tests based on pymox.""" """Simplify unit tests based on pymox."""
from __future__ import print_function from __future__ import print_function
@@ -12,51 +11,52 @@ import string
class TestCaseUtils(object): class TestCaseUtils(object):
"""Base class with some additional functionalities. People will usually want """Base class with some additional functionalities. People will usually want
to use SuperMoxTestBase instead.""" to use SuperMoxTestBase instead."""
# Backup the separator in case it gets mocked # Backup the separator in case it gets mocked
_OS_SEP = os.sep _OS_SEP = os.sep
_RANDOM_CHOICE = random.choice _RANDOM_CHOICE = random.choice
_RANDOM_RANDINT = random.randint _RANDOM_RANDINT = random.randint
_STRING_LETTERS = string.ascii_letters _STRING_LETTERS = string.ascii_letters
## Some utilities for generating arbitrary arguments. ## Some utilities for generating arbitrary arguments.
def String(self, max_length): def String(self, max_length):
return ''.join([self._RANDOM_CHOICE(self._STRING_LETTERS) return ''.join([
for _ in range(self._RANDOM_RANDINT(1, max_length))]) self._RANDOM_CHOICE(self._STRING_LETTERS)
for _ in range(self._RANDOM_RANDINT(1, max_length))
])
def Strings(self, max_arg_count, max_arg_length): def Strings(self, max_arg_count, max_arg_length):
return [self.String(max_arg_length) for _ in range(max_arg_count)] return [self.String(max_arg_length) for _ in range(max_arg_count)]
def Args(self, max_arg_count=8, max_arg_length=16): def Args(self, max_arg_count=8, max_arg_length=16):
return self.Strings(max_arg_count, return self.Strings(max_arg_count,
self._RANDOM_RANDINT(1, max_arg_length)) self._RANDOM_RANDINT(1, max_arg_length))
def _DirElts(self, max_elt_count=4, max_elt_length=8): def _DirElts(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length)) return self._OS_SEP.join(self.Strings(max_elt_count, max_elt_length))
def Dir(self, max_elt_count=4, max_elt_length=8): def Dir(self, max_elt_count=4, max_elt_length=8):
return (self._RANDOM_CHOICE((self._OS_SEP, '')) + return (self._RANDOM_CHOICE(
self._DirElts(max_elt_count, max_elt_length)) (self._OS_SEP, '')) + self._DirElts(max_elt_count, max_elt_length))
def RootDir(self, max_elt_count=4, max_elt_length=8): def RootDir(self, max_elt_count=4, max_elt_length=8):
return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length) return self._OS_SEP + self._DirElts(max_elt_count, max_elt_length)
def compareMembers(self, obj, members): def compareMembers(self, obj, members):
"""If you add a member, be sure to add the relevant test!""" """If you add a member, be sure to add the relevant test!"""
# Skip over members starting with '_' since they are usually not meant to # Skip over members starting with '_' since they are usually not meant
# be for public use. # to be for public use.
actual_members = [x for x in sorted(dir(obj)) actual_members = [x for x in sorted(dir(obj)) if not x.startswith('_')]
if not x.startswith('_')] expected_members = sorted(members)
expected_members = sorted(members) if actual_members != expected_members:
if actual_members != expected_members: diff = ([i for i in actual_members if i not in expected_members] +
diff = ([i for i in actual_members if i not in expected_members] + [i for i in expected_members if i not in actual_members])
[i for i in expected_members if i not in actual_members]) print(diff, file=sys.stderr)
print(diff, file=sys.stderr) # pylint: disable=no-member
# pylint: disable=no-member self.assertEqual(actual_members, expected_members)
self.assertEqual(actual_members, expected_members)
def setUp(self): def setUp(self):
self.root_dir = self.Dir() self.root_dir = self.Dir()
self.args = self.Args() self.args = self.Args()
self.relpath = self.String(200) self.relpath = self.String(200)

View File

@@ -15,83 +15,84 @@ import gclient_utils
class TrialDir(object): class TrialDir(object):
"""Manages a temporary directory. """Manages a temporary directory.
On first object creation, TrialDir.TRIAL_ROOT will be set to a new temporary On first object creation, TrialDir.TRIAL_ROOT will be set to a new temporary
directory created in /tmp or the equivalent. It will be deleted on process directory created in /tmp or the equivalent. It will be deleted on process
exit unless TrialDir.SHOULD_LEAK is set to True. exit unless TrialDir.SHOULD_LEAK is set to True.
""" """
# When SHOULD_LEAK is set to True, temporary directories created while the # When SHOULD_LEAK is set to True, temporary directories created while the
# tests are running aren't deleted at the end of the tests. Expect failures # tests are running aren't deleted at the end of the tests. Expect failures
# when running more than one test due to inter-test side-effects. Helps with # when running more than one test due to inter-test side-effects. Helps with
# debugging. # debugging.
SHOULD_LEAK = False SHOULD_LEAK = False
# Main root directory. # Main root directory.
TRIAL_ROOT = None TRIAL_ROOT = None
def __init__(self, subdir, leak=False): def __init__(self, subdir, leak=False):
self.leak = self.SHOULD_LEAK or leak self.leak = self.SHOULD_LEAK or leak
self.subdir = subdir self.subdir = subdir
self.root_dir = None self.root_dir = None
def set_up(self): def set_up(self):
"""All late initialization comes here.""" """All late initialization comes here."""
# You can override self.TRIAL_ROOT. # You can override self.TRIAL_ROOT.
if not self.TRIAL_ROOT: if not self.TRIAL_ROOT:
# Was not yet initialized. # Was not yet initialized.
TrialDir.TRIAL_ROOT = os.path.realpath(tempfile.mkdtemp(prefix='trial')) TrialDir.TRIAL_ROOT = os.path.realpath(
atexit.register(self._clean) tempfile.mkdtemp(prefix='trial'))
self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir) atexit.register(self._clean)
gclient_utils.rmtree(self.root_dir) self.root_dir = os.path.join(TrialDir.TRIAL_ROOT, self.subdir)
os.makedirs(self.root_dir) gclient_utils.rmtree(self.root_dir)
os.makedirs(self.root_dir)
def tear_down(self): def tear_down(self):
"""Cleans the trial subdirectory for this instance.""" """Cleans the trial subdirectory for this instance."""
if not self.leak: if not self.leak:
logging.debug('Removing %s' % self.root_dir) logging.debug('Removing %s' % self.root_dir)
gclient_utils.rmtree(self.root_dir) gclient_utils.rmtree(self.root_dir)
else: else:
logging.error('Leaking %s' % self.root_dir) logging.error('Leaking %s' % self.root_dir)
self.root_dir = None self.root_dir = None
@staticmethod @staticmethod
def _clean(): def _clean():
"""Cleans the root trial directory.""" """Cleans the root trial directory."""
if not TrialDir.SHOULD_LEAK: if not TrialDir.SHOULD_LEAK:
logging.debug('Removing %s' % TrialDir.TRIAL_ROOT) logging.debug('Removing %s' % TrialDir.TRIAL_ROOT)
gclient_utils.rmtree(TrialDir.TRIAL_ROOT) gclient_utils.rmtree(TrialDir.TRIAL_ROOT)
else: else:
logging.error('Leaking %s' % TrialDir.TRIAL_ROOT) logging.error('Leaking %s' % TrialDir.TRIAL_ROOT)
class TrialDirMixIn(object): class TrialDirMixIn(object):
def setUp(self): def setUp(self):
# Create a specific directory just for the test. # Create a specific directory just for the test.
self.trial = TrialDir(self.id()) self.trial = TrialDir(self.id())
self.trial.set_up() self.trial.set_up()
def tearDown(self): def tearDown(self):
self.trial.tear_down() self.trial.tear_down()
@property @property
def root_dir(self): def root_dir(self):
return self.trial.root_dir return self.trial.root_dir
class TestCase(unittest.TestCase, TrialDirMixIn): class TestCase(unittest.TestCase, TrialDirMixIn):
"""Base unittest class that cleans off a trial directory in tearDown().""" """Base unittest class that cleans off a trial directory in tearDown()."""
def setUp(self): def setUp(self):
unittest.TestCase.setUp(self) unittest.TestCase.setUp(self)
TrialDirMixIn.setUp(self) TrialDirMixIn.setUp(self)
def tearDown(self): def tearDown(self):
TrialDirMixIn.tearDown(self) TrialDirMixIn.tearDown(self)
unittest.TestCase.tearDown(self) unittest.TestCase.tearDown(self)
if '-l' in sys.argv: if '-l' in sys.argv:
# See SHOULD_LEAK definition in TrialDir for its purpose. # See SHOULD_LEAK definition in TrialDir for its purpose.
TrialDir.SHOULD_LEAK = True TrialDir.SHOULD_LEAK = True
print('Leaking!') print('Leaking!')
sys.argv.remove('-l') sys.argv.remove('-l')