65 lines
1.6 KiB
Python
65 lines
1.6 KiB
Python
|
import sys
|
||
|
import tempfile
|
||
|
import os.path
|
||
|
import os
|
||
|
import shutil
|
||
|
|
||
|
# Echo the argument being processed.
print(sys.argv[-1])

# The last command-line argument must be a path to a data.yaml file. With no
# argument at all, sys.argv[-1] is the script's own path, which the length
# check rejects regardless of what the script file happens to be called.
if len(sys.argv) < 2 or 'data.yaml' not in sys.argv[-1]:
    print('Please provide a path to a data.yaml file', file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which is missing when
    # Python runs with -S or inside frozen executables.
    sys.exit(1)
|
||
|
|
||
|
filePath = sys.argv[-1]
# Output lines of the rewritten data.yaml; filled in further down the script.
data = []

# Resolve to an absolute path before taking the directory: for a bare
# 'data.yaml' in the current directory, os.path.dirname() returns '', which
# the existence check below would reject even though the dataset is present.
basePath = os.path.dirname(os.path.abspath(filePath))
trainPath = os.path.join(basePath, 'train', 'images')
validPath = os.path.join(basePath, 'valid', 'images')
testPath = os.path.join(basePath, 'test', 'images')

# The dataset root and each split's image directory must already exist;
# report the first one that is missing and exit with status 2.
for requiredDir in (basePath, trainPath, validPath, testPath):
    if not os.path.isdir(requiredDir):
        print(f'Dataset directory {requiredDir} doesn\'t exist', file=sys.stderr)
        sys.exit(2)

# The argument only had to *contain* 'data.yaml'; make sure it is a real file
# so the rewrite step fails with a clear message instead of a traceback.
if not os.path.isfile(filePath):
    print(f'Data file {filePath} doesn\'t exist', file=sys.stderr)
    sys.exit(2)
|
||
|
|
||
|
def isInList(match, l):
    """Return True if *match* is contained (``in``) in any element of *l*.

    Elements are examined in order and the search stops at the first hit;
    an empty *l* yields False.
    """
    return any(match in item for item in l)
|
||
|
|
||
|
# Rewrite the split entries of data.yaml in place: every top-level
# ``train:``/``val:``/``test:`` line is replaced with the image directory
# derived from the data.yaml location, and any split the file does not
# define is appended at the end.
splitPaths = (
    ('train', trainPath),
    ('val', validPath),
    ('test', testPath),
)

# YAML files are UTF-8. 'utf-8-sig' also strips a leading byte-order mark,
# which would otherwise prevent a key on the first line from being recognised.
with open(filePath, 'r', encoding='utf-8-sig') as file:
    lines = file.readlines()

foundKeys = set()
for line in lines:
    # Match keys only at the start of the line (top level). A plain substring
    # test would also hit unrelated keys such as ``interval: `` (contains
    # ``val: ``).
    for key, path in splitPaths:
        if line.startswith(f'{key}:'):
            # Use '\n', not os.linesep: text-mode writes already translate
            # '\n' to the platform separator, so os.linesep would become
            # '\r\r\n' on Windows.
            # NOTE(review): the path is written unquoted; a path containing
            # ': ' or ' #' would need YAML quoting.
            data.append(f'{key}: {path}\n')
            foundKeys.add(key)
            break
    else:
        data.append(line)

# Append every split the file did not define. Each key is checked
# independently; an if/elif chain would add at most one of them.
missingSplits = [(key, path) for key, path in splitPaths if key not in foundKeys]
if missingSplits and data and not data[-1].endswith('\n'):
    # The file had no final newline; don't glue the new key onto the last line.
    data[-1] += '\n'
data.extend(f'{key}: {path}\n' for key, path in missingSplits)

# Write the result straight back. The former temp-file + shutil.copy detour
# gave no atomicity (copy truncates the target too). It also copied the temp
# file's restrictive permission bits onto data.yaml, and it failed on Windows,
# where an open NamedTemporaryFile cannot be reopened by name.
with open(filePath, 'w', encoding='utf-8') as file:
    file.writelines(data)
|