Create a formated json object when running scripts/data.py
This commit is contained in:
parent
f9b69a1b26
commit
1acbcfbe91
|
@ -3,19 +3,23 @@ import tempfile
|
|||
import os.path
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
import json
|
||||
|
||||
print(sys.argv[-1])
|
||||
parser = argparse.ArgumentParser(description='Tool for fixing data.yaml and creating config json')
|
||||
parser.add_argument('yaml', type=str, help="The data.yaml file to correct and parse")
|
||||
parser.add_argument('name', type=str, help='The name of the dataset')
|
||||
|
||||
if sys.argv[-1] == 'data.py' or 'data.yaml' not in sys.argv[-1]:
|
||||
print('Please provide a path to a data.yaml file')
|
||||
exit(1)
|
||||
args = parser.parse_args()
|
||||
|
||||
filePath = sys.argv[-1]
|
||||
filePath = args.yaml
|
||||
name = args.name
|
||||
data = []
|
||||
basePath = os.path.dirname(filePath)
|
||||
trainPath = os.path.join(basePath, 'train', 'images')
|
||||
validPath = os.path.join(basePath, 'valid', 'images')
|
||||
testPath = os.path.join(basePath, 'test', 'images')
|
||||
jsonPath = os.path.join(basePath, f'{name}.json')
|
||||
|
||||
if not os.path.exists(basePath):
|
||||
print(f'Dataset directory {basePath} doesn\'t exist')
|
||||
|
@ -68,3 +72,50 @@ with tempfile.NamedTemporaryFile() as tmp:
|
|||
with open(tmp.name, 'w') as t:
|
||||
t.writelines(data)
|
||||
shutil.copy(tmp.name, filePath)
|
||||
|
||||
jsonObj = {
|
||||
'processor' : {
|
||||
'type' : 'category',
|
||||
'arity' : 1
|
||||
},
|
||||
'config' : {
|
||||
'model' : name,
|
||||
'type' : 'onnx',
|
||||
'backend' : 'CPU',
|
||||
'input_size' : 640,
|
||||
'conf_thresh' : 0.3,
|
||||
'score_thresh' : 0.3,
|
||||
'nms_thresh' : 0.3
|
||||
},
|
||||
'predicates' : []
|
||||
}
|
||||
|
||||
with open(filePath, 'r') as file:
|
||||
lines = file.readlines()
|
||||
for line in lines:
|
||||
if line[0:4] == 'nc: ' :
|
||||
parts = line.strip().split(': ')
|
||||
jsonObj['config']['classes'] = int(parts[1])
|
||||
else if line[0:7] == 'names: ' :
|
||||
classesLine = line.strip().replace('names: ', '').replace('[', '').replace(']', '')
|
||||
classes = classesLine.split(',')
|
||||
i = 0
|
||||
for cl in classes :
|
||||
className = cl.replace("'", '').strip()
|
||||
jsonObj['predicates'].append({
|
||||
'name' : className,
|
||||
'classID' : i
|
||||
})
|
||||
i += 1
|
||||
|
||||
if jsonObj['config']['classes'] == len(jsonObj['predicates']) :
|
||||
print('JSON looks good')
|
||||
else :
|
||||
print('JSON is questionable')
|
||||
|
||||
formatted = json.dumps(jsonObj, indent=4)
|
||||
print(formatted)
|
||||
with open(jsonPath, 'w') as outfile:
|
||||
outfile.write(formatted)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue