Create a formated json object when running scripts/data.py

This commit is contained in:
mmcwilliams 2023-12-15 12:01:10 -05:00
parent f9b69a1b26
commit 1acbcfbe91
1 changed files with 56 additions and 5 deletions

View File

@ -3,19 +3,23 @@ import tempfile
import os.path import os.path
import os import os
import shutil import shutil
import argparse
import json
print(sys.argv[-1]) parser = argparse.ArgumentParser(description='Tool for fixing data.yaml and creating config json')
parser.add_argument('yaml', type=str, help="The data.yaml file to correct and parse")
parser.add_argument('name', type=str, help='The name of the dataset')
if sys.argv[-1] == 'data.py' or 'data.yaml' not in sys.argv[-1]: args = parser.parse_args()
print('Please provide a path to a data.yaml file')
exit(1)
filePath = sys.argv[-1] filePath = args.yaml
name = args.name
data = [] data = []
basePath = os.path.dirname(filePath) basePath = os.path.dirname(filePath)
trainPath = os.path.join(basePath, 'train', 'images') trainPath = os.path.join(basePath, 'train', 'images')
validPath = os.path.join(basePath, 'valid', 'images') validPath = os.path.join(basePath, 'valid', 'images')
testPath = os.path.join(basePath, 'test', 'images') testPath = os.path.join(basePath, 'test', 'images')
jsonPath = os.path.join(basePath, f'{name}.json')
if not os.path.exists(basePath): if not os.path.exists(basePath):
print(f'Dataset directory {basePath} doesn\'t exist') print(f'Dataset directory {basePath} doesn\'t exist')
@ -68,3 +72,50 @@ with tempfile.NamedTemporaryFile() as tmp:
with open(tmp.name, 'w') as t: with open(tmp.name, 'w') as t:
t.writelines(data) t.writelines(data)
shutil.copy(tmp.name, filePath) shutil.copy(tmp.name, filePath)
jsonObj = {
'processor' : {
'type' : 'category',
'arity' : 1
},
'config' : {
'model' : name,
'type' : 'onnx',
'backend' : 'CPU',
'input_size' : 640,
'conf_thresh' : 0.3,
'score_thresh' : 0.3,
'nms_thresh' : 0.3
},
'predicates' : []
}
with open(filePath, 'r') as file:
lines = file.readlines()
for line in lines:
if line[0:4] == 'nc: ' :
parts = line.strip().split(': ')
jsonObj['config']['classes'] = int(parts[1])
else if line[0:7] == 'names: ' :
classesLine = line.strip().replace('names: ', '').replace('[', '').replace(']', '')
classes = classesLine.split(',')
i = 0
for cl in classes :
className = cl.replace("'", '').strip()
jsonObj['predicates'].append({
'name' : className,
'classID' : i
})
i += 1
if jsonObj['config']['classes'] == len(jsonObj['predicates']) :
print('JSON looks good')
else :
print('JSON is questionable')
formatted = json.dumps(jsonObj, indent=4)
print(formatted)
with open(jsonPath, 'w') as outfile:
outfile.write(formatted)