diff --git a/scripts/data.py b/scripts/data.py
index 151e179..0873d1d 100644
--- a/scripts/data.py
+++ b/scripts/data.py
@@ -3,19 +3,23 @@ import tempfile
 import os.path
 import os
 import shutil
+import argparse
+import json
 
-print(sys.argv[-1])
+parser = argparse.ArgumentParser(description='Tool for fixing data.yaml and creating config json')
+parser.add_argument('yaml', type=str, help="The data.yaml file to correct and parse")
+parser.add_argument('name', type=str, help='The name of the dataset')
 
-if sys.argv[-1] == 'data.py' or 'data.yaml' not in sys.argv[-1]:
-    print('Please provide a path to a data.yaml file')
-    exit(1)
+args = parser.parse_args()
 
-filePath = sys.argv[-1]
+filePath = args.yaml
+name = args.name
 data = []
 basePath = os.path.dirname(filePath)
 trainPath = os.path.join(basePath, 'train', 'images')
 validPath = os.path.join(basePath, 'valid', 'images')
 testPath = os.path.join(basePath, 'test', 'images')
+jsonPath = os.path.join(basePath, f'{name}.json')
 
 if not os.path.exists(basePath):
     print(f'Dataset directory {basePath} doesn\'t exist')
@@ -68,3 +72,52 @@ with tempfile.NamedTemporaryFile() as tmp:
     with open(tmp.name, 'w') as t:
         t.writelines(data)
     shutil.copy(tmp.name, filePath)
+
+# Skeleton of the config JSON; 'classes' and 'predicates' are filled in
+# below from the data.yaml that was just rewritten.
+jsonObj = {
+    'processor' : {
+        'type' : 'category',
+        'arity' : 1
+    },
+    'config' : {
+        'model' : name,
+        'type' : 'onnx',
+        'backend' : 'CPU',
+        'input_size' : 640,
+        'conf_thresh' : 0.3,
+        'score_thresh' : 0.3,
+        'nms_thresh' : 0.3
+    },
+    'predicates' : []
+}
+
+# Only the single-line forms "nc: <int>" and "names: ['a', 'b', ...]" are
+# recognised; a block-style YAML list of names is not parsed here.
+with open(filePath, 'r') as file:
+    for line in file:
+        if line.startswith('nc: '):
+            parts = line.strip().split(': ')
+            jsonObj['config']['classes'] = int(parts[1])
+        elif line.startswith('names: '):
+            classesLine = line.strip().replace('names: ', '').replace('[', '').replace(']', '')
+            for i, cl in enumerate(classesLine.split(',')):
+                className = cl.replace("'", '').strip()
+                jsonObj['predicates'].append({
+                    'name' : className,
+                    'classID' : i
+                })
+
+# .get() so a data.yaml without an "nc: " line is reported as questionable
+# instead of raising KeyError.
+if jsonObj['config'].get('classes') == len(jsonObj['predicates']) :
+    print('JSON looks good')
+else :
+    print('JSON is questionable')
+
+formatted = json.dumps(jsonObj, indent=4)
+print(formatted)
+with open(jsonPath, 'w') as outfile:
+    outfile.write(formatted)
+
+