122 lines
3.0 KiB
Python
122 lines
3.0 KiB
Python
import sys
|
|
import tempfile
|
|
import os.path
|
|
import os
|
|
import shutil
|
|
import argparse
|
|
import json
|
|
|
|
# Command-line interface: the data.yaml to rewrite and the dataset name, which
# is used both as the output JSON file name and as the 'model' config value.
parser = argparse.ArgumentParser(description='Tool for fixing data.yaml and creating config json')
parser.add_argument('yaml', type=str, help="The data.yaml file to correct and parse")
parser.add_argument('name', type=str, help='The name of the dataset')

args = parser.parse_args()

filePath = args.yaml
name = args.name

# Collects the lines of the rewritten data.yaml (filled further down).
data = []

# os.path.dirname() returns '' when only a bare file name such as "data.yaml"
# is given, and os.path.exists('') is always False.  Fall back to the current
# directory so the tool also works when run from inside the dataset folder.
basePath = os.path.dirname(filePath) or os.curdir
trainPath = os.path.join(basePath, 'train', 'images')
validPath = os.path.join(basePath, 'valid', 'images')
testPath = os.path.join(basePath, 'test', 'images')
jsonPath = os.path.join(basePath, f'{name}.json')

# Abort with exit status 2 at the first missing directory, checked in the
# order base, train, valid, test.
for requiredDir in (basePath, trainPath, validPath, testPath):
    if not os.path.exists(requiredDir):
        print(f'Dataset directory {requiredDir} doesn\'t exist')
        # sys.exit() instead of the bare exit(): the latter is injected by the
        # `site` module and is missing under `python -S` or frozen builds.
        sys.exit(2)
|
|
|
|
def isInList(match, l):
    """Return True if ``match`` is contained in at least one element of ``l``.

    Args:
        match: Value searched for; it is tested with the ``in`` operator
            against every element (a substring test when the elements are
            strings).
        l: Iterable of containers to search.

    Returns:
        bool: True as soon as one element contains ``match``; False when no
        element does, including when ``l`` is empty.
    """
    return any(match in item for item in l)
|
|
|
|
# Rewrite the train/val/test entries of data.yaml so they point at the image
# directories derived from the location of the YAML file.
# Dict order gives the matching precedence: train, then val, then test.
datasetDirs = {
    'train: ': trainPath,
    'val: ': validPath,
    'test: ': testPath,
}
seenKeys = set()

with open(filePath, 'r', encoding='utf-8') as file:
    for line in file:
        # A line is treated as an entry when the key text occurs anywhere in it.
        key = next((k for k in datasetDirs if k in line), None)
        if key is None:
            data.append(line)
            continue
        seenKeys.add(key)
        # Plain '\n' rather than os.linesep: the file is written in text mode,
        # which already translates '\n' to the platform line ending, so
        # os.linesep would come out as '\r\r\n' on Windows.
        data.append('\n')
        data.append(f'{key}{datasetDirs[key]}\n')

# Append every entry the file did not contain.  These are independent checks:
# an if/elif chain would add only the first missing entry when several are
# absent.  The leading '\n' also terminates a last line that has no newline.
for key, path in datasetDirs.items():
    if key not in seenKeys:
        data.append('\n')
        data.append(f'{key}{path}\n')

# Write the result straight back.  Re-opening a NamedTemporaryFile by name
# fails on Windows while it is still open, and copying it over the target
# with shutil.copy() also replaces the target's permission bits.
with open(filePath, 'w', encoding='utf-8') as file:
    file.writelines(data)
|
|
|
|
# Skeleton of the config JSON.  'classes' (from the "nc: " line) and the
# 'predicates' entries (from the "names: " line) are filled in from data.yaml.
jsonObj = {
    'processor': {
        'type': 'category',
        'arity': 1
    },
    'config': {
        'model': name,
        'type': 'onnx',
        'backend': 'CPU',
        'input_size': 640,
        'conf_thresh': 0.3,
        'score_thresh': 0.3,
        'nms_thresh': 0.3
    },
    'predicates': []
}

# NOTE: only single-line entries that start in column 0 are recognised, i.e.
# "nc: <int>" and an inline list such as "names: ['a', 'b']".
with open(filePath, 'r', encoding='utf-8') as file:
    for line in file:
        if line.startswith('nc: '):
            # Ignore a trailing "# comment" so int() does not choke on it.
            countText = line[len('nc: '):].split('#', 1)[0].strip()
            jsonObj['config']['classes'] = int(countText)
        elif line.startswith('names: '):
            classesLine = line.strip().replace('names: ', '').replace('[', '').replace(']', '')
            # Drop single and double quotes around each name, and skip empty
            # items so that "names: []" does not produce a nameless predicate.
            classNames = [
                cl.replace("'", '').replace('"', '').strip()
                for cl in classesLine.split(',')
            ]
            classNames = [className for className in classNames if className]
            for classID, className in enumerate(classNames):
                jsonObj['predicates'].append({
                    'name': className,
                    'classID': classID
                })

# Sanity check: the declared class count must match the number of names.
# .get() keeps a file without an "nc: " line from raising KeyError; it is
# reported as questionable instead.
if jsonObj['config'].get('classes') == len(jsonObj['predicates']):
    print('JSON looks good')
else:
    print('JSON is questionable')

formatted = json.dumps(jsonObj, indent=4)
print(formatted)

with open(jsonPath, 'w') as outfile:
    outfile.write(formatted)
|
|
|
|
|