Create a formated json object when running scripts/data.py
This commit is contained in:
parent
f9b69a1b26
commit
1acbcfbe91
|
@ -3,19 +3,23 @@ import tempfile
|
||||||
import os.path
|
import os.path
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
print(sys.argv[-1])
|
parser = argparse.ArgumentParser(description='Tool for fixing data.yaml and creating config json')
|
||||||
|
parser.add_argument('yaml', type=str, help="The data.yaml file to correct and parse")
|
||||||
|
parser.add_argument('name', type=str, help='The name of the dataset')
|
||||||
|
|
||||||
if sys.argv[-1] == 'data.py' or 'data.yaml' not in sys.argv[-1]:
|
args = parser.parse_args()
|
||||||
print('Please provide a path to a data.yaml file')
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
filePath = sys.argv[-1]
|
filePath = args.yaml
|
||||||
|
name = args.name
|
||||||
data = []
|
data = []
|
||||||
basePath = os.path.dirname(filePath)
|
basePath = os.path.dirname(filePath)
|
||||||
trainPath = os.path.join(basePath, 'train', 'images')
|
trainPath = os.path.join(basePath, 'train', 'images')
|
||||||
validPath = os.path.join(basePath, 'valid', 'images')
|
validPath = os.path.join(basePath, 'valid', 'images')
|
||||||
testPath = os.path.join(basePath, 'test', 'images')
|
testPath = os.path.join(basePath, 'test', 'images')
|
||||||
|
jsonPath = os.path.join(basePath, f'{name}.json')
|
||||||
|
|
||||||
if not os.path.exists(basePath):
|
if not os.path.exists(basePath):
|
||||||
print(f'Dataset directory {basePath} doesn\'t exist')
|
print(f'Dataset directory {basePath} doesn\'t exist')
|
||||||
|
@ -68,3 +72,50 @@ with tempfile.NamedTemporaryFile() as tmp:
|
||||||
with open(tmp.name, 'w') as t:
|
with open(tmp.name, 'w') as t:
|
||||||
t.writelines(data)
|
t.writelines(data)
|
||||||
shutil.copy(tmp.name, filePath)
|
shutil.copy(tmp.name, filePath)
|
||||||
|
|
||||||
|
jsonObj = {
|
||||||
|
'processor' : {
|
||||||
|
'type' : 'category',
|
||||||
|
'arity' : 1
|
||||||
|
},
|
||||||
|
'config' : {
|
||||||
|
'model' : name,
|
||||||
|
'type' : 'onnx',
|
||||||
|
'backend' : 'CPU',
|
||||||
|
'input_size' : 640,
|
||||||
|
'conf_thresh' : 0.3,
|
||||||
|
'score_thresh' : 0.3,
|
||||||
|
'nms_thresh' : 0.3
|
||||||
|
},
|
||||||
|
'predicates' : []
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(filePath, 'r') as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
for line in lines:
|
||||||
|
if line[0:4] == 'nc: ' :
|
||||||
|
parts = line.strip().split(': ')
|
||||||
|
jsonObj['config']['classes'] = int(parts[1])
|
||||||
|
else if line[0:7] == 'names: ' :
|
||||||
|
classesLine = line.strip().replace('names: ', '').replace('[', '').replace(']', '')
|
||||||
|
classes = classesLine.split(',')
|
||||||
|
i = 0
|
||||||
|
for cl in classes :
|
||||||
|
className = cl.replace("'", '').strip()
|
||||||
|
jsonObj['predicates'].append({
|
||||||
|
'name' : className,
|
||||||
|
'classID' : i
|
||||||
|
})
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if jsonObj['config']['classes'] == len(jsonObj['predicates']) :
|
||||||
|
print('JSON looks good')
|
||||||
|
else :
|
||||||
|
print('JSON is questionable')
|
||||||
|
|
||||||
|
formatted = json.dumps(jsonObj, indent=4)
|
||||||
|
print(formatted)
|
||||||
|
with open(jsonPath, 'w') as outfile:
|
||||||
|
outfile.write(formatted)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue