1
0
mirror of synced 2024-11-22 09:14:23 +01:00

MAT2's cli now uses meaningful return codes

- Simplify the multiprocessing by using a Pool
- Use some functional (♥) constructions to exit
  with a return code
- Add some tests to prove that we're doing things
  that are working correctly
This commit is contained in:
jvoisin 2018-04-29 22:55:26 +02:00
parent a79c9410af
commit d2b2a54a72
2 changed files with 28 additions and 30 deletions

42
main.py
View File

@ -1,11 +1,12 @@
#!/usr/bin/python3 #!/usr/bin/python3
import os import os
from typing import Tuple
import sys
import itertools
import mimetypes import mimetypes
import argparse import argparse
from threading import Thread
import multiprocessing import multiprocessing
from queue import Queue
from src import parser_factory from src import parser_factory
@ -52,14 +53,15 @@ def show_meta(filename:str):
print(" %s: harmful content" % k) print(" %s: harmful content" % k)
def clean_meta(filename:str, is_lightweigth:bool) -> bool: def clean_meta(params:Tuple[str, bool]) -> bool:
filename, is_lightweigth = params
if not __check_file(filename, os.R_OK|os.W_OK): if not __check_file(filename, os.R_OK|os.W_OK):
return return
p, mtype = parser_factory.get_parser(filename) p, mtype = parser_factory.get_parser(filename)
if p is None: if p is None:
print("[-] %s's format (%s) is not supported" % (filename, mtype)) print("[-] %s's format (%s) is not supported" % (filename, mtype))
return return False
if is_lightweigth: if is_lightweigth:
return p.remove_all_lightweight() return p.remove_all_lightweight()
return p.remove_all() return p.remove_all()
@ -82,15 +84,6 @@ def __get_files_recursively(files):
for _f in _files: for _f in _files:
yield os.path.join(path, _f) yield os.path.join(path, _f)
def __do_clean_async(is_lightweigth, q):
while True:
f = q.get()
if f is None: # nothing more to process
return
clean_meta(f, is_lightweigth)
q.task_done()
def main(): def main():
arg_parser = create_arg_parser() arg_parser = create_arg_parser()
args = arg_parser.parse_args() args = arg_parser.parse_args()
@ -106,24 +99,13 @@ def main():
show_meta(f) show_meta(f)
return return
else: # Thread the cleaning else:
p = multiprocessing.Pool()
mode = (args.lightweight is True) mode = (args.lightweight is True)
q = Queue(maxsize=0) l = zip(__get_files_recursively(args.files), itertools.repeat(mode))
threads = list()
for f in __get_files_recursively(args.files):
q.put(f)
for _ in range(multiprocessing.cpu_count()):
worker = Thread(target=__do_clean_async, args=(mode, q))
worker.start()
threads.append(worker)
for _ in range(multiprocessing.cpu_count()):
q.put(None) # stop the threads
for worker in threads:
worker.join()
ret = list(p.imap_unordered(clean_meta, list(l)))
return 0 if all(ret) else -1
if __name__ == '__main__': if __name__ == '__main__':
main() sys.exit(main())

View File

@ -16,6 +16,22 @@ class TestHelp(unittest.TestCase):
self.assertIn(b'usage: main.py [-h] [-c] [-l] [-s] [-L] [files [files ...]]', stdout) self.assertIn(b'usage: main.py [-h] [-c] [-l] [-s] [-L] [files [files ...]]', stdout)
class TestReturnValue(unittest.TestCase):
def test_nonzero(self):
ret = subprocess.call(['./main.py', './main.py'], stdout=subprocess.DEVNULL)
self.assertEqual(255, ret)
ret = subprocess.call(['./main.py', '--whololo'], stderr=subprocess.DEVNULL)
self.assertEqual(2, ret)
def test_zero(self):
ret = subprocess.call(['./main.py'], stdout=subprocess.DEVNULL)
self.assertEqual(0, ret)
ret = subprocess.call(['./main.py', '--show', './main.py'], stdout=subprocess.DEVNULL)
self.assertEqual(0, ret)
class TestCleanMeta(unittest.TestCase): class TestCleanMeta(unittest.TestCase):
def test_jpg(self): def test_jpg(self):
shutil.copy('./tests/data/dirty.jpg', './tests/data/clean.jpg') shutil.copy('./tests/data/dirty.jpg', './tests/data/clean.jpg')