109 commits
136a767 code to produce models (Dec 5, 2019)
d4bc814 add files needed for training (Dec 5, 2019)
038d776 add files needed for training (Dec 5, 2019)
cd1990d Update config_params.json (Dec 5, 2019)
a216dcc Update README (Dec 5, 2019)
036e2e9 Update README (Dec 5, 2019)
f7a5a57 Update README (Dec 5, 2019)
bbe6f99 Update README (Dec 5, 2019)
e4013fe Delete README (Dec 5, 2019)
f69d445 Add new file (Dec 5, 2019)
4897fd3 📝 howto: Be more verbose with the subtree pull (mikegerber, Dec 9, 2019)
0cddfff Update README (vahidrezanezhad, Dec 10, 2019)
c5e1e2d Update README.md (vahidrezanezhad, Dec 10, 2019)
3ac99b4 Merge commit 'c5e1e2dda7542c6d8a9787fa496b538ce8519794' (Dec 10, 2019)
bb212da Update main.py (vahidrezanezhad, Dec 10, 2019)
2e768e4 Add LICENSE (cneud, Dec 10, 2019)
8bdb295 Merge pull request #2 from cneud/add-license-1 (cneud, Jan 15, 2020)
d2a8119 Update README.md (cneud, Jan 15, 2020)
a9c86b2 Update README.md (cneud, Jan 15, 2020)
5b4df66 Merge pull request #7 from qurator-spk/update-readme (vahidrezanezhad, Jan 16, 2020)
7063789 Update README.md (vahidrezanezhad, Jan 16, 2020)
63fcb96 Update README.md (vahidrezanezhad, Jan 16, 2020)
5fb7552 first updates, padding, rotations (Jun 22, 2021)
4bea9fd continue training, losses and etc (Jun 22, 2021)
75dc5f3 Merge pull request #15 from vahidrezanezhad/master (vahidrezanezhad, Jun 22, 2021)
040d3cf Update README.md (vahidrezanezhad, Jun 23, 2021)
e698463 Update README.md (vahidrezanezhad, Jun 23, 2021)
3ec551d Update README.md (vahidrezanezhad, Jun 23, 2021)
57f8827 Update README.md (vahidrezanezhad, Jun 23, 2021)
9221b6c Update README.md (vahidrezanezhad, Jun 23, 2021)
5cd47a8 Update README.md (vahidrezanezhad, Jun 29, 2021)
11ae468 Update README.md (vahidrezanezhad, Jun 29, 2021)
d7e265e Update README.md (vahidrezanezhad, Jun 29, 2021)
31b0102 Update README.md (vahidrezanezhad, Jun 29, 2021)
da55806 Update README.md (vahidrezanezhad, Jul 14, 2021)
dbb4040 supposed to solve https://github.com/qurator-spk/sbb_binarization/iss… (Aug 22, 2022)
522f00a adjusting to tf2 (vahidrezanezhad, Apr 4, 2024)
7dfaafe adding requirements (vahidrezanezhad, Apr 4, 2024)
c4bcfc1 use headless cv2 (cneud, Apr 10, 2024)
0103e14 add info on helpful tools (fix #14) (cneud, Apr 10, 2024)
5f84938 update parameter config docs (fix #11) (cneud, Apr 10, 2024)
02b1436 code formatting with black; typos (cneud, Apr 10, 2024)
d27647a first working update of branch (vahidrezanezhad, Apr 15, 2024)
dbb8450 integrating first working classification training model (vahidrezanezhad, Apr 29, 2024)
38db3e9 adding enhancement training (vahidrezanezhad, May 6, 2024)
8d1050e inference script is added (vahidrezanezhad, May 7, 2024)
ce1108a modifications (vahidrezanezhad, May 7, 2024)
a7e1f25 Update train.py (vahidrezanezhad, May 8, 2024)
45aba32 Update utils.py (vahidrezanezhad, May 12, 2024)
6ef8658 adding page xml to label generator (vahidrezanezhad, May 16, 2024)
2623113 page to label enable textline new concept (vahidrezanezhad, May 17, 2024)
5f06a02 update requirements (vahidrezanezhad, May 17, 2024)
f7dda07 page2label with a dynamic layout (vahidrezanezhad, May 22, 2024)
d687f53 dynamic layout decorated with artificial class on text elements boundry (vahidrezanezhad, May 23, 2024)
947a0e0 missing text types are added (vahidrezanezhad, May 23, 2024)
4b7f7da use cases like textline, word and glyph are added (vahidrezanezhad, May 23, 2024)
f574601 use case printspace is added (vahidrezanezhad, May 23, 2024)
bf14683 machine based reading order training dataset generator is added (vahidrezanezhad, May 24, 2024)
4e4490d machine based reading order training is integrated (vahidrezanezhad, May 24, 2024)
5aa6ee0 adding rest_as_paragraph and rest_as_graphic to elements (vahidrezanezhad, May 27, 2024)
29ddd4d pass degrading scales for image enhancement as a json file (vahidrezanezhad, May 28, 2024)
356da4c min area size of text region passes as an argument for machine based … (vahidrezanezhad, May 28, 2024)
2e7c69f inference for reading order (vahidrezanezhad, May 28, 2024)
f6abefb reading order detection on xml with layout + result will be written i… (vahidrezanezhad, May 29, 2024)
7850335 min_area size of regions considered for reading order detection passe… (vahidrezanezhad, May 29, 2024)
4640d9f modifying xml parsing (vahidrezanezhad, May 30, 2024)
821290c scaling and cropping of labels and org images (vahidrezanezhad, May 30, 2024)
b9cbc0e replacement in a list done correctly (vahidrezanezhad, Jun 6, 2024)
e25a925 Update README.md (vahidrezanezhad, Jun 6, 2024)
1c8873f just defined textregion types can be extracted as label (vahidrezanezhad, Jun 6, 2024)
b1d971a just defined textregion types can be extracted as label (vahidrezanezhad, Jun 6, 2024)
dc356a5 just defined graphic region types can be extracted as label (vahidrezanezhad, Jun 6, 2024)
815e5a1 updating train.py (vahidrezanezhad, Jun 7, 2024)
41a0e15 updating train.py nontransformer backend (vahidrezanezhad, Jun 10, 2024)
2aa216e binarization as a separate task of segmentation (vahidrezanezhad, Jun 11, 2024)
f1fd74c transformer patch size is dynamic now. (vahidrezanezhad, Jun 12, 2024)
743f2e9 Transformer+CNN structure is added to vision transformer type (vahidrezanezhad, Jun 12, 2024)
9358657 update config (vahidrezanezhad, Jun 12, 2024)
033cf67 update reading order machine based (vahidrezanezhad, Jun 21, 2024)
c0faece update inference (vahidrezanezhad, Jun 21, 2024)
647a3f8 resolving typo (vahidrezanezhad, Jul 9, 2024)
55f3cb9 printspace_as_class_in_layout is integrated. Printspace can be define… (vahidrezanezhad, Jul 16, 2024)
9521768 adding degrading and brightness augmentation to no patches case training (vahidrezanezhad, Jul 17, 2024)
f2692cf brightness augmentation modified (Jul 17, 2024)
c340fbb increasing margin in the case of pixelwise inference (Jul 23, 2024)
30894dd erosion and dilation parameters are changed & separators are written … (vahidrezanezhad, Jul 24, 2024)
5fbe941 inference updated (vahidrezanezhad, Jul 24, 2024)
59e5892 erosion rate changed (vahidrezanezhad, Aug 1, 2024)
b6bdf94 add documentation from wiki as markdown file to the codebase (cneud, Aug 8, 2024)
f4bad09 save only layout output. different from overlayed layout on image (vahidrezanezhad, Aug 9, 2024)
85dd59f update (vahidrezanezhad, Aug 9, 2024)
7be326d augmentation function for red textlines, rgb background and scaling f… (vahidrezanezhad, Aug 20, 2024)
95bbdf8 updating augmentations (vahidrezanezhad, Aug 21, 2024)
f31219b scaling, channels shuffling, rgb background and red content added to … (vahidrezanezhad, Aug 21, 2024)
9904846 using prepared binarized images in the case of augmentation (vahidrezanezhad, Aug 22, 2024)
4f0e3ef early dilation for textline artificial class (vahidrezanezhad, Aug 27, 2024)
c502e67 adding foreground rgb to augmentation (vahidrezanezhad, Aug 28, 2024)
5f456cf fixing artificial class bug (vahidrezanezhad, Aug 28, 2024)
cca4d17 new augmentations for patchwise training (vahidrezanezhad, Aug 30, 2024)
df4a47a Update inference.py to check if save_layout was passed as argument ot… (johnlockejrr, Oct 19, 2024)
451188c Changed deprecated `lr` to `learning_rate` and `model.fit_generator` … (johnlockejrr, Oct 19, 2024)
be57f13 Update utils.py (johnlockejrr, May 11, 2025)
102b04c Update utils.py (johnlockejrr, May 11, 2025)
1bf8019 Update gt_gen_utils.py (johnlockejrr, May 14, 2025)
7661080 LR Warmup and Optimization Implementation (johnlockejrr, May 17, 2025)
f298643 Fix `ReduceONPlateau` wrong logic (johnlockejrr, May 17, 2025)
30fe51f move src/.../train.py to root to accomodate old PR (kba, Oct 16, 2025)
54132a4 Merge remote-tracking branch 'pixelwise_local/ReduceLROnPlateau' into… (kba, Oct 16, 2025)
ad53ea3 move train.py back (kba, Oct 16, 2025)
Empty file added .gitkeep
Empty file.
Empty file added __init__.py
Empty file.
158 changes: 121 additions & 37 deletions src/eynollah/training/train.py
@@ -31,6 +31,7 @@
import tensorflow as tf
from tensorflow.compat.v1.keras.backend import set_session
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, Callback, ModelCheckpoint
from sacred import Experiment
from tensorflow.keras.models import load_model
from tqdm import tqdm
@@ -61,6 +62,20 @@ def on_train_batch_end(self, batch, logs=None):
        json.dump(self._config, fp)  # encode dict into JSON
        print(f"saved model as steps {self.step_count} to {save_file}")

def get_warmup_schedule(start_lr, target_lr, warmup_epochs, steps_per_epoch):
    warmup_steps = warmup_epochs * steps_per_epoch

    # PolynomialDecay with power=1.0 is a linear ramp from start_lr to
    # target_lr over warmup_steps steps (it mirrors the warmup optimizer
    # built in run() below).
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=start_lr,
        end_learning_rate=target_lr,
        decay_steps=warmup_steps,
        power=1.0
    )

    return lr_schedule


def configuration():
    config = tf.compat.v1.ConfigProto()
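
As a quick sanity check of the helper above: PolynomialDecay with power=1.0 ramps linearly and then holds the final value. A sketch assuming TensorFlow 2.x, with illustrative step counts and the start/target rates taken from the sample config at the bottom of this page:

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-6,  # warmup_start_lr
    end_learning_rate=2e-5,      # learning_rate
    decay_steps=5 * 100,         # warmup_epochs * steps_per_epoch (illustrative)
    power=1.0                    # linear ramp
)
for step in (0, 250, 500, 750):
    print(step, float(schedule(step)))
# prints 1e-06, 1.05e-05, 2e-05, 2e-05: the schedule clamps at decay_steps,
# so the rate holds at the target once warmup ends

One caveat worth testing: Keras' ReduceLROnPlateau reads and writes the optimizer's learning rate as a plain float, so enabling it together with a schedule-driven optimizer (warmup_enabled plus reduce_lr_enabled, as the sample config below does) may raise an error on some TF 2.x versions; implementing the warmup as a callback instead is a possible workaround.
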
@@ -80,7 +95,6 @@ def get_dirs_or_files(input_data):

ex = Experiment(save_git_info=False)


@ex.config
def config_params():
n_classes = None # Number of classes. In the case of binary classification this should be 2.
Expand Down Expand Up @@ -145,6 +159,19 @@ def config_params():
number_of_backgrounds_per_image = 1
dir_rgb_backgrounds = None
dir_rgb_foregrounds = None
reduce_lr_enabled = False # Whether to use ReduceLROnPlateau callback
reduce_lr_monitor = 'val_loss' # Metric to monitor for reducing learning rate
reduce_lr_factor = 0.5 # Factor to reduce learning rate by
reduce_lr_patience = 3 # Number of epochs to wait before reducing learning rate
reduce_lr_min_lr = 1e-6 # Minimum learning rate
reduce_lr_min_delta = 1e-5 # Minimum change in monitored value to be considered as improvement
early_stopping_enabled = False # Whether to use EarlyStopping callback
early_stopping_monitor = 'val_loss' # Metric to monitor for early stopping
early_stopping_patience = 10 # Number of epochs to wait before stopping
early_stopping_restore_best_weights = True # Whether to restore best weights when stopping
warmup_enabled = False # Whether to use learning rate warmup
warmup_epochs = 5 # Number of epochs for warmup
warmup_start_lr = 1e-6 # Starting learning rate for warmup

@ex.automain
def run(_config, n_classes, n_epochs, input_height,
@@ -159,7 +186,10 @@
transformer_mlp_head_units, transformer_layers, transformer_num_heads, transformer_cnn_first,
transformer_patchsize_x, transformer_patchsize_y,
transformer_num_patches_xy, backbone_type, save_interval, flip_index, dir_eval, dir_output,
-        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds):
+        pretraining, learning_rate, task, f1_threshold_classification, classification_classes_name, dir_img_bin, number_of_backgrounds_per_image,dir_rgb_backgrounds, dir_rgb_foregrounds,
+        reduce_lr_enabled, reduce_lr_monitor, reduce_lr_factor, reduce_lr_patience, reduce_lr_min_lr, reduce_lr_min_delta,
+        early_stopping_enabled, early_stopping_monitor, early_stopping_patience, early_stopping_restore_best_weights,
+        warmup_enabled, warmup_epochs, warmup_start_lr):

if dir_rgb_backgrounds:
list_all_possible_background_images = os.listdir(dir_rgb_backgrounds)
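
An aside for readers new to Sacred: every key defined in config_params above is injected into run() by parameter name and can be overridden from the command line or a JSON file. A minimal, self-contained sketch of the mechanism (illustrative names and values, not this repo's code):

from sacred import Experiment

ex = Experiment(save_git_info=False)

@ex.config
def config_params():
    reduce_lr_enabled = False  # same defaulting pattern as train.py
    warmup_epochs = 5

@ex.automain
def run(reduce_lr_enabled, warmup_epochs):
    # `python sketch.py with reduce_lr_enabled=True warmup_epochs=3` overrides
    # the defaults at the CLI; `python sketch.py with config.json` loads a file
    print(reduce_lr_enabled, warmup_epochs)
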
@@ -320,20 +350,91 @@ def run(_config, n_classes, n_epochs, input_height,
#if you want to see the model structure just uncomment model summary.
model.summary()

# Create callbacks list
callbacks = []
if reduce_lr_enabled:
reduce_lr = ReduceLROnPlateau(
monitor=reduce_lr_monitor,
factor=reduce_lr_factor,
patience=reduce_lr_patience,
min_lr=reduce_lr_min_lr,
min_delta=reduce_lr_min_delta,
verbose=1
)
callbacks.append(reduce_lr)

if early_stopping_enabled:
early_stopping = EarlyStopping(
monitor=early_stopping_monitor,
patience=early_stopping_patience,
restore_best_weights=early_stopping_restore_best_weights,
verbose=1
)
callbacks.append(early_stopping)

# Add checkpoint to save models every epoch
class ModelCheckpointWithConfig(ModelCheckpoint):
def __init__(self, *args, **kwargs):
self._config = _config
super().__init__(*args, **kwargs)

def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
model_dir = os.path.join(dir_output, f"model_{epoch+1}")
with open(os.path.join(model_dir, "config.json"), "w") as fp:
json.dump(self._config, fp)

checkpoint_epoch = ModelCheckpointWithConfig(
os.path.join(dir_output, "model_{epoch}"),
save_freq='epoch',
save_weights_only=False,
save_best_only=False,
verbose=1
)
callbacks.append(checkpoint_epoch)

# Calculate steps per epoch
steps_per_epoch = int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1

# Create optimizer with or without warmup
if warmup_enabled:
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=warmup_start_lr,
decay_steps=warmup_epochs * steps_per_epoch,
end_learning_rate=learning_rate,
power=1.0 # Linear decay
)
optimizer = Adam(learning_rate=lr_schedule)
else:
optimizer = Adam(learning_rate=learning_rate)

if (task == "segmentation" or task == "binarization"):
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer=optimizer, metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer=optimizer, metrics=['accuracy'])
if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer=optimizer, metrics=['accuracy'])
elif task == "enhancement":
model.compile(loss='mean_squared_error',
optimizer=optimizer, metrics=['accuracy'])

if task == "segmentation" or task == "binarization":
if not is_loss_soft_dice and not weighted_loss:
model.compile(loss='categorical_crossentropy',
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
optimizer=optimizer, metrics=['accuracy'])
if is_loss_soft_dice:
model.compile(loss=soft_dice_loss,
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
optimizer=optimizer, metrics=['accuracy'])
if weighted_loss:
model.compile(loss=weighted_categorical_crossentropy(weights),
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
optimizer=optimizer, metrics=['accuracy'])
elif task == "enhancement":
model.compile(loss='mean_squared_error',
optimizer=Adam(learning_rate=learning_rate), metrics=['accuracy'])
optimizer=optimizer, metrics=['accuracy'])


# generating train and evaluation data
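
With the checkpoint callback above, dir_output should accumulate one SavedModel directory per epoch (model_1, model_2, ...), each holding the Sacred config beside the weights. A sketch of reloading one for inspection, assuming TF 2.x, the nontransformer backbone, and the illustrative run directory from the sample config below:

import json
import os

from tensorflow.keras.models import load_model

run_dir = "runs/sam_41_mss_npt_448x448"
model = load_model(os.path.join(run_dir, "model_3"), compile=False)  # checkpoint written after epoch 3
with open(os.path.join(run_dir, "model_3", "config.json")) as fp:
    cfg = json.load(fp)
print(cfg["task"], model.output_shape)
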
@@ -342,39 +443,22 @@ def run(_config, n_classes, n_epochs, input_height,
val_gen = data_gen(dir_flow_eval_imgs, dir_flow_eval_labels, batch_size=n_batch,
input_height=input_height, input_width=input_width, n_classes=n_classes, task=task)

-    ##img_validation_patches = os.listdir(dir_flow_eval_imgs)
-    ##score_best=[]
-    ##score_best.append(0)
+    # Single fit call with all epochs
+    history = model.fit(
+        train_gen,
+        steps_per_epoch=steps_per_epoch,
+        validation_data=val_gen,
+        validation_steps=1,
+        epochs=n_epochs,
+        callbacks=callbacks
+    )

-    if save_interval:
-        save_weights_callback = SaveWeightsAfterSteps(save_interval, dir_output, _config)
-
-    for i in tqdm(range(index_start, n_epochs + index_start)):
-        if save_interval:
-            model.fit(
-                train_gen,
-                steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                validation_data=val_gen,
-                validation_steps=1,
-                epochs=1, callbacks=[save_weights_callback])
-        else:
-            model.fit(
-                train_gen,
-                steps_per_epoch=int(len(os.listdir(dir_flow_train_imgs)) / n_batch) - 1,
-                validation_data=val_gen,
-                validation_steps=1,
-                epochs=1)
-
-        model.save(os.path.join(dir_output,'model_'+str(i)))
+    # Save the best model (either from early stopping or final model)
+    model.save(os.path.join(dir_output, 'model_best'))

-        with open(os.path.join(os.path.join(dir_output,'model_'+str(i)),"config.json"), "w") as fp:
-            json.dump(_config, fp)  # encode dict into JSON
-
-        #os.system('rm -rf '+dir_train_flowing)
-        #os.system('rm -rf '+dir_eval_flowing)
-
-        #model.save(dir_output+'/'+'model'+'.h5')
+    with open(os.path.join(dir_output, 'model_best', "config.json"), "w") as fp:
+        json.dump(_config, fp)  # encode dict into JSON

elif task=='classification':
configuration()
model = resnet50_classifier(n_classes, input_height, input_width, weight_decay, pretraining)
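
Since the epoch loop is replaced by a single fit call, the returned History object now carries metrics for the whole run. One way to keep them next to the final model (a sketch; the history.json name is not part of the PR):

# after: history = model.fit(...) in the segmentation/binarization branch
import json
import os

hist = {k: [float(v) for v in vals] for k, vals in history.history.items()}
with open(os.path.join(dir_output, "model_best", "history.json"), "w") as fp:
    json.dump(hist, fp)  # per-epoch loss/accuracy, including val_* metrics
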
53 changes: 53 additions & 0 deletions train_no_patches_448x448.json
@@ -0,0 +1,53 @@
{
"backbone_type" : "nontransformer",
"task": "segmentation",
"n_classes" : 3,
"n_epochs" : 50,
"input_height" : 448,
"input_width" : 448,
"weight_decay" : 1e-4,
"n_batch" : 4,
"learning_rate": 2e-5,
"patches" : false,
"pretraining" : true,
"augmentation" : true,
"flip_aug" : false,
"blur_aug" : false,
"scaling" : true,
"degrading": true,
"brightening": false,
"binarization" : false,
"scaling_bluring" : false,
"scaling_binarization" : false,
"scaling_flip" : false,
"rotation": false,
"rotation_not_90": false,
"blur_k" : ["blur","guass","median"],
"scales" : [0.6, 0.7, 0.8, 0.9],
"brightness" : [1.3, 1.5, 1.7, 2],
"degrade_scales" : [0.2, 0.4],
"flip_index" : [0, 1, -1],
"thetha" : [10, -10],
"continue_training": false,
"index_start" : 0,
"dir_of_start_model" : " ",
"weighted_loss": false,
"is_loss_soft_dice": true,
"data_is_provided": true,
"dir_train": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/train",
"dir_eval": "/home/incognito/sbb_pixelwise_segmentation/dataset/sam_41_mss/dir_train/eval",
"dir_output": "runs/sam_41_mss_npt_448x448",
"reduce_lr_enabled": true,
"reduce_lr_monitor": "val_loss",
"reduce_lr_factor": 0.2,
"reduce_lr_patience": 3,
"reduce_lr_min_delta": 1e-5,
"reduce_lr_min_lr": 1e-6,
"early_stopping_enabled": true,
"early_stopping_monitor": "val_loss",
"early_stopping_patience": 6,
"early_stopping_restore_best_weights": true,
"warmup_enabled": true,
"warmup_epochs": 5,
"warmup_start_lr": 1e-6
}
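
This file can be passed straight to Sacred's CLI (python train.py with train_no_patches_448x448.json) or applied programmatically. A sketch of the latter; it assumes train.py is importable from the current directory (in this PR it lives under src/eynollah/training, so adjust the import or sys.path accordingly):

import json

from train import ex  # the Sacred Experiment defined in train.py

with open("train_no_patches_448x448.json") as fp:
    overrides = json.load(fp)
ex.run(config_updates=overrides)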