yolo_web/scripts/data.py

122 lines
3.0 KiB
Python
Raw Permalink Normal View History

2023-06-29 03:29:45 +00:00
import sys
import tempfile
import os.path
import os
import shutil
import argparse
import json
2023-06-29 03:29:45 +00:00
# Command-line interface: two required positionals, the data.yaml file to
# rewrite and the name of the dataset.
parser = argparse.ArgumentParser(
    description='Tool for fixing data.yaml and creating config json',
)
parser.add_argument(
    'yaml',
    type=str,
    help="The data.yaml file to correct and parse",
)
parser.add_argument(
    'name',
    type=str,
    help='The name of the dataset',
)

args = parser.parse_args()

# Unpack the parsed positionals into the module-level names used by the rest
# of the script.
filePath, name = args.yaml, args.name
2023-06-29 03:29:45 +00:00
# Accumulates the rewritten lines of data.yaml before they are written back.
data = []

# Directory that holds data.yaml. os.path.dirname() returns '' when the file
# is given as a bare name (e.g. "data.yaml"), and os.path.exists('') is always
# False, so fall back to the current directory in that case.
basePath = os.path.dirname(filePath) or os.curdir

# Image folders expected next to data.yaml, one per dataset split.
trainPath = os.path.join(basePath, 'train', 'images')
validPath = os.path.join(basePath, 'valid', 'images')
testPath = os.path.join(basePath, 'test', 'images')

# The generated config is written next to data.yaml as "<name>.json".
jsonPath = os.path.join(basePath, f'{name}.json')
2023-06-29 03:29:45 +00:00
# Abort early (exit status 2) when the dataset layout is incomplete: the base
# directory and each of the train/valid/test image folders must exist.
for requiredPath in (basePath, trainPath, validPath, testPath):
    if not os.path.exists(requiredPath):
        print(f'Dataset directory {requiredPath} doesn\'t exist')
        # sys.exit() instead of the bare exit() helper: exit() is injected by
        # the site module for interactive use and is not guaranteed to exist
        # (e.g. when Python runs with -S).
        sys.exit(2)
def isInList(match, l):
    """Return True if `match` is contained in at least one element of `l`.

    Containment is tested with the `in` operator on each element, so for a
    list of strings this is a substring search. An empty `l` yields False.
    """
    return any(match in element for element in l)
# Rewrite the train/val/test entries of data.yaml so that they point at the
# image folders computed above; every other line is copied through unchanged.
# The result is collected in `data`.
splitPaths = {
    'train: ': trainPath,
    'val: ': validPath,
    'test: ': testPath,
}
foundKeys = set()

with open(filePath, 'r') as file:
    for line in file:
        # Match on the start of the (left-stripped) line rather than anywhere
        # in it, so comments or keys that merely contain one of the names
        # (e.g. "interval: ") are not overwritten by mistake.
        content = line.lstrip()
        key = next((k for k in splitPaths if content.startswith(k)), None)
        if key is None:
            data.append(line)
            continue
        foundKeys.add(key)
        # A blank separator line precedes each rewritten entry. Plain '\n' is
        # used because the file is written in text mode, which already
        # translates newlines; os.linesep would become '\r\r\n' on Windows.
        data.append('\n')
        data.append(f'{key}{splitPaths[key]}\n')

# Append every entry that was missing from the file. Each key is checked
# independently so that several missing entries are all added, not just the
# first one.
for key, splitPath in splitPaths.items():
    if key not in foundKeys:
        data.append('\n')
        data.append(f'{key}{splitPath}\n')
# Write the corrected lines back over data.yaml.
# The content is already fully held in `data`, so it is written straight to
# the target. Staging it in a NamedTemporaryFile that is re-opened by name
# while still open does not work on Windows, and shutil.copy() would also
# replace data.yaml's permission bits with those of the private temp file.
with open(filePath, 'w') as yamlOut:
    yamlOut.writelines(data)
# Skeleton of the JSON config written next to data.yaml. The class count
# ('classes') and the 'predicates' entries are filled in later from the
# contents of data.yaml.
processorSection = {
    'type': 'category',
    'arity': 1,
}

# Fixed settings; only 'model' depends on the command-line arguments.
configSection = {
    'model': name,
    'type': 'onnx',
    'backend': 'CPU',
    'input_size': 640,
    'conf_thresh': 0.3,
    'score_thresh': 0.3,
    'nms_thresh': 0.3,
}

jsonObj = {
    'processor': processorSection,
    'config': configSection,
    'predicates': [],
}
# Read the class count ("nc: N") and the inline class list
# ("names: ['a', 'b']") from data.yaml into jsonObj. Only entries that start
# at the beginning of a line are recognised.
with open(filePath, 'r') as file:
    for line in file:
        if line.startswith('nc: '):
            # Drop a trailing "# comment" before converting, so that
            # "nc: 3  # classes" does not raise ValueError.
            countText = line[len('nc: '):].split('#', 1)[0].strip()
            jsonObj['config']['classes'] = int(countText)
        elif line.startswith('names: '):
            listText = line[len('names: '):].strip()
            listText = listText.replace('[', '').replace(']', '')
            # Strip surrounding single or double quotes from each entry and
            # skip empty ones, so "names: []" or a trailing comma does not
            # produce a predicate with an empty name.
            cleaned = (raw.strip().strip('\'"').strip() for raw in listText.split(','))
            classNames = [className for className in cleaned if className]
            for classID, className in enumerate(classNames):
                jsonObj['predicates'].append({
                    'name': className,
                    'classID': classID,
                })
# Sanity check: the declared class count should equal the number of parsed
# class names. .get() is used because 'classes' is only set when data.yaml
# contains an "nc: " line; a missing count is reported as questionable
# instead of crashing with KeyError.
if jsonObj['config'].get('classes') == len(jsonObj['predicates']):
    print('JSON looks good')
else:
    print('JSON is questionable')

# Show the resulting config on stdout, then store it as <name>.json.
formatted = json.dumps(jsonObj, indent=4)
print(formatted)
with open(jsonPath, 'w') as outfile:
    outfile.write(formatted)