Replace the string-based unknown-member policy with an UnknownMemberPolicy enum.

* libmat2/__init__.py: add the UnknownMemberPolicy enum (ABORT, OMIT, KEEP).
* libmat2/office.py: use the enum for the unknown_member_policy attribute of
  ArchiveBasedAbstractParser and drop the string validation at the top of
  remove_all().
* mat2: convert --unknown-members to the enum in main(), pass the enum to
  clean_meta(), and build the list of valid values in the help text from
  the enum.
* tests/test_policy.py: use the enum, also remove
  ./tests/data/clean.cleaned.docx after the omit/keep tests, and expect
  ValueError from the enum constructor instead of from remove_all().

NOTE(review): main() calls UnknownMemberPolicy(args.unknown_members) and no
handler is visible in this hunk, so an invalid --unknown-members value
surfaces as a ValueError; consider argparse `choices=` if a usage error is
preferred -- confirm against the full file.

NOTE(review): the line breaks, blank context lines and leading indentation
of this patch were restored to match the hunk headers; verify against the
target tree (or apply with whitespace-tolerant options) before merging.

diff --git a/libmat2/__init__.py b/libmat2/__init__.py
index bf4e813..8a5b064 100644
--- a/libmat2/__init__.py
+++ b/libmat2/__init__.py
@@ -2,6 +2,7 @@
 
 import os
 import collections
+from enum import Enum
 import importlib
 from typing import Dict, Optional
 
@@ -62,3 +63,8 @@ def check_dependencies() -> dict:
             ret[value] = False # pragma: no cover
 
     return ret
+
+class UnknownMemberPolicy(Enum):
+    ABORT = 'abort'
+    OMIT = 'omit'
+    KEEP = 'keep'
diff --git a/libmat2/office.py b/libmat2/office.py
index 29100df..60c5478 100644
--- a/libmat2/office.py
+++ b/libmat2/office.py
@@ -9,7 +9,7 @@ from typing import Dict, Set, Pattern
 
 import xml.etree.ElementTree as ET # type: ignore
 
-from . import abstract, parser_factory
+from . import abstract, parser_factory, UnknownMemberPolicy
 
 # Make pyflakes happy
 assert Set
@@ -37,8 +37,8 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
     files_to_omit = set() # type: Set[Pattern]
 
     # what should the parser do if it encounters an unknown file in
-    # the archive? valid policies are 'abort', 'omit', 'keep'
-    unknown_member_policy = 'abort' # type: str
+    # the archive?
+    unknown_member_policy = UnknownMemberPolicy.ABORT # type: UnknownMemberPolicy
 
     def __init__(self, filename):
         super().__init__(filename)
@@ -81,10 +81,6 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
     def remove_all(self) -> bool:
         # pylint: disable=too-many-branches
 
-        if self.unknown_member_policy not in ['omit', 'keep', 'abort']:
-            logging.error("The policy %s is invalid.", self.unknown_member_policy)
-            raise ValueError
-
         with zipfile.ZipFile(self.filename) as zin,\
              zipfile.ZipFile(self.output_filename, 'w') as zout:
 
@@ -113,11 +109,11 @@ class ArchiveBasedAbstractParser(abstract.AbstractParser):
                     # supported files that we want to clean then add
                     tmp_parser, mtype = parser_factory.get_parser(full_path) # type: ignore
                     if not tmp_parser:
-                        if self.unknown_member_policy == 'omit':
+                        if self.unknown_member_policy == UnknownMemberPolicy.OMIT:
                             logging.warning("In file %s, omitting unknown element %s (format: %s)",
                                             self.filename, item.filename, mtype)
                             continue
-                        elif self.unknown_member_policy == 'keep':
+                        elif self.unknown_member_policy == UnknownMemberPolicy.KEEP:
                             logging.warning("In file %s, keeping unknown element %s (format: %s)",
                                             self.filename, item.filename, mtype)
                         else:
diff --git a/mat2 b/mat2
index 2a8ef46..0aba8d1 100755
--- a/mat2
+++ b/mat2
@@ -10,7 +10,8 @@ import multiprocessing
 import logging
 
 try:
-    from libmat2 import parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies
+    from libmat2 import (parser_factory, UNSUPPORTED_EXTENSIONS, check_dependencies,
+                         UnknownMemberPolicy)
 except ValueError as e:
     print(e)
     sys.exit(1)
@@ -42,8 +43,8 @@ def create_arg_parser():
     parser.add_argument('-V', '--verbose', action='store_true',
                         help='show more verbose status information')
     parser.add_argument('--unknown-members', metavar='policy', default='abort',
-                        help='how to handle unknown members of archive-style files ' +
-                        '(policy should be abort, omit, or keep)')
+                        help='how to handle unknown members of archive-style files (policy should' +
+                        ' be one of: ' + ', '.join([x.value for x in UnknownMemberPolicy]) + ')')
 
 
     info = parser.add_mutually_exclusive_group()
@@ -70,7 +71,7 @@ def show_meta(filename: str):
         except UnicodeEncodeError:
             print(" %s: harmful content" % k)
 
-def clean_meta(params: Tuple[str, bool, str]) -> bool:
+def clean_meta(params: Tuple[str, bool, UnknownMemberPolicy]) -> bool:
     filename, is_lightweight, unknown_member_policy = params
     if not __check_file(filename, os.R_OK|os.W_OK):
         return False
@@ -137,15 +138,13 @@ def main():
         return 0
 
     else:
-        if args.unknown_members == 'keep':
+        unknown_member_policy = UnknownMemberPolicy(args.unknown_members)
+        if unknown_member_policy == UnknownMemberPolicy.KEEP:
             logging.warning('Keeping unknown member files may leak metadata in the resulting file!')
-        elif args.unknown_members not in ['omit', 'abort']:
-            logging.warning('Undefined policy for handling unknown member files: "%s"',
-                            args.unknown_members)
         p = multiprocessing.Pool()
         mode = (args.lightweight is True)
         l = zip(__get_files_recursively(args.files), itertools.repeat(mode),
-                itertools.repeat(args.unknown_members))
+                itertools.repeat(unknown_member_policy))
 
         ret = list(p.imap_unordered(clean_meta, list(l)))
         return 0 if all(ret) else -1
diff --git a/tests/test_policy.py b/tests/test_policy.py
index 39282b1..5a8447b 100644
--- a/tests/test_policy.py
+++ b/tests/test_policy.py
@@ -4,28 +4,29 @@ import unittest
 import shutil
 import os
 
-from libmat2 import office
+from libmat2 import office, UnknownMemberPolicy
 
 
 class TestPolicy(unittest.TestCase):
     def test_policy_omit(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'omit'
+        p.unknown_member_policy = UnknownMemberPolicy.OMIT
         self.assertTrue(p.remove_all())
         os.remove('./tests/data/clean.docx')
+        os.remove('./tests/data/clean.cleaned.docx')
 
     def test_policy_keep(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'keep'
+        p.unknown_member_policy = UnknownMemberPolicy.KEEP
         self.assertTrue(p.remove_all())
         os.remove('./tests/data/clean.docx')
+        os.remove('./tests/data/clean.cleaned.docx')
 
     def test_policy_unknown(self):
         shutil.copy('./tests/data/embedded.docx', './tests/data/clean.docx')
         p = office.MSOfficeParser('./tests/data/clean.docx')
-        p.unknown_member_policy = 'unknown_policy_name_totally_invalid'
         with self.assertRaises(ValueError):
-            p.remove_all()
+            p.unknown_member_policy = UnknownMemberPolicy('unknown_policy_name_totally_invalid')
         os.remove('./tests/data/clean.docx')