From 3ff1fe1076203c51d2c4316882d1eb9f4091154d Mon Sep 17 00:00:00 2001 From: dazory Date: Tue, 20 Aug 2024 17:12:53 +0900 Subject: [PATCH 1/7] add oamix --- ...faster-rcnn_r50_fpn_1x_cityscapes_oamix.py | 69 ++++ mmdet/datasets/transforms/__init__.py | 7 +- mmdet/datasets/transforms/augment_wrappers.py | 4 +- mmdet/datasets/transforms/colorspace.py | 30 ++ mmdet/datasets/transforms/geometric.py | 351 ++++++++++++++++ mmdet/datasets/transforms/oa_mix.py | 385 ++++++++++++++++++ 6 files changed, 841 insertions(+), 5 deletions(-) create mode 100644 configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py create mode 100644 mmdet/datasets/transforms/oa_mix.py diff --git a/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py new file mode 100644 index 00000000000..6517e3b83c1 --- /dev/null +++ b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py @@ -0,0 +1,69 @@ +_base_ = [ + '../_base_/models/faster-rcnn_r50_fpn.py', + '../_base_/datasets/cityscapes_detection.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py' +] + + +# OA-Mix +oamix_config=dict( + type='OAMix', version='oamix', + box_scale=(0.05, 0.3), box_ratio=(3, 0.33), + sigma_ratio=0.2, score_thresh=10, +) + +backend_args = None +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=[(2048, 800), (2048, 1024)], + keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + oamix_config, + dict(type='PackDetInputs') +] +train_dataloader = dict( + num_workers=8, + dataset=dict(dataset=dict(pipeline=train_pipeline)) +) + + +# Model +model = dict( + backbone=dict(init_cfg=None), + roi_head=dict( + bbox_head=dict( + num_classes=8, + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)))) + +# optimizer +# lr is set for a batch size of 8 +optim_wrapper = dict(optimizer=dict(lr=0.01)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=8, + by_epoch=True, + # [7] yields higher performance than [6] + milestones=[7], + gamma=0.1) +] + +# actual epoch = 8 * 8 = 64 +train_cfg = dict(max_epochs=8) + +# For better, more stable performance initialize from COCO +load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (1 samples per GPU) +# TODO: support auto scaling lr +# auto_scale_lr = dict(base_batch_size=8) diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index ab3478feb00..8e4ba6d2880 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -2,13 +2,14 @@ from .augment_wrappers import AutoAugment, RandAugment from .colorspace import (AutoContrast, Brightness, Color, ColorTransform, Contrast, Equalize, Invert, Posterize, Sharpness, - Solarize, SolarizeAdd) + Solarize, SolarizeAdd, Invert4Mix) from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs, PackTrackInputs, ToTensor, Transpose) from .frame_sampling import BaseFrameSample, UniformRefFrameSample from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX, TranslateY) from .instaboost import InstaBoost +from .oa_mix import OAMix from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations, LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, @@ -35,11 +36,11 @@ 'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste', 'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform', 'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize', - 'AutoContrast', 'Invert', 'MultiBranch', 'RandomErasing', + 'AutoContrast', 'Invert', 'Invert4Mix', 'MultiBranch', 'RandomErasing', 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP', - 'RandomSamplingNegPos', 'LoadTextAnnotations' + 'RandomSamplingNegPos', 'LoadTextAnnotations' 'OAMix', ] diff --git a/mmdet/datasets/transforms/augment_wrappers.py b/mmdet/datasets/transforms/augment_wrappers.py index 19fae6efdf6..d796b561159 100644 --- a/mmdet/datasets/transforms/augment_wrappers.py +++ b/mmdet/datasets/transforms/augment_wrappers.py @@ -77,9 +77,9 @@ def level_to_mag(level: Optional[int], min_mag: float, max_mag: float) -> float: """Map from level to magnitude.""" if level is None: - return round(np.random.rand() * (max_mag - min_mag) + min_mag, 1) + return round(np.random.rand() * (max_mag - min_mag) + min_mag, 2) else: - return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 1) + return round(level / _MAX_LEVEL * (max_mag - min_mag) + min_mag, 2) @TRANSFORMS.register_module() diff --git a/mmdet/datasets/transforms/colorspace.py b/mmdet/datasets/transforms/colorspace.py index e0ba2e97c7e..4b830fe560a 100644 --- a/mmdet/datasets/transforms/colorspace.py +++ b/mmdet/datasets/transforms/colorspace.py @@ -491,3 +491,33 @@ def _transform_img(self, results: dict, mag: float) -> None: """Invert the image.""" img = results['img'] results['img'] = mmcv.iminvert(img).astype(img.dtype) + + +@TRANSFORMS.register_module() +class Invert4Mix(ColorTransform): + """Invert and translate images. + + Required Keys: + + - img + + Modified Keys: + + - img + + Args: + prob (float): The probability for performing invert therefore should + be in range [0, 1]. Defaults to 1.0. + level (int, optional): No use for Invert transformation. + Defaults to None. + min_mag (float): No use for Invert transformation. Defaults to 0.1. + max_mag (float): No use for Invert transformation. Defaults to 1.9. + """ + + def _transform_img(self, results: dict, mag: float) -> None: + """Invert the image.""" + img = results['img'] + img = mmcv.iminvert(img).astype(img.dtype) + img = mmcv.imtranslate(img, 1, 'horizontal') + img = mmcv.imtranslate(img, 1, 'vertical') + results['img'] = img.astype(img.dtype) diff --git a/mmdet/datasets/transforms/geometric.py b/mmdet/datasets/transforms/geometric.py index d2cd6be258f..9e76bcc15a3 100644 --- a/mmdet/datasets/transforms/geometric.py +++ b/mmdet/datasets/transforms/geometric.py @@ -752,3 +752,354 @@ def _transform_seg(self, results: dict, mag: float) -> None: direction='vertical', border_value=self.seg_ignore_label, interpolation='nearest') + + +@TRANSFORMS.register_module() +class BBoxShearX(ShearX): + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + cy = (bbox[1] + bbox[3]) / 2 + + shear_matrix = np.array([[1, mag, -mag*cy], [0, 1, 0]], dtype=np.float32) + shear_img = cv2.warpAffine( + img_orig, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + img = (1.0 - mask) * img + mask * shear_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxShearY(ShearY): + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + cx = (bbox[0] + bbox[2]) / 2 + + shear_matrix = np.float32([[1, 0, 0], [mag, 1, -mag * cx]]) + shear_img = cv2.warpAffine( + img_orig, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + img = (1.0 - mask) * img + mask * shear_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxRotate(Rotate): + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + cx = (bbox[0] + bbox[2]) / 2 + cy = (bbox[1] + bbox[3]) / 2 + + translate_matrix = np.float32([[1, 0, w_img//2 - cx], [0, 1, h_img//2 - cy]]) + translated_img = cv2.warpAffine( + img_orig, + translate_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + """Rotate the image.""" + rotated_img = mmcv.imrotate( + translated_img, + mag, + border_value=self.img_border_value, + interpolation=self.interpolation) + translate_matrix = np.float32([[1, 0, -w_img//2 + cx], [0, 1, -h_img//2 + cy]]) + rotated_img = cv2.warpAffine( + rotated_img, + translate_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + img = (1.0 - mask) * img + mask * rotated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxTranslateX(TranslateX): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + w = bbox[2] - bbox[0] + + """Translate the image horizontally.""" + _mag = int(w * mag) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + img = (1.0 - mask) * img + mask * translated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BBoxTranslateY(TranslateY): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + + img = np.zeros_like(results['img'], dtype=np.float32) + for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + h = bbox[3] - bbox[1] + + """Translate the image horizontally.""" + _mag = int(h * mag) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + img = (1.0 - mask) * img + mask * translated_img + mask_max = np.max(results['masks'], axis=0) + img = (1.0 - mask_max) * img_orig + mask_max * img + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgShearX(ShearX): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + """Shear the image horizontally.""" + mask_max = np.max(results['masks'], axis=0) + + shear_matrix = np.array([[1, mag, -mag*h_img//2], [0, 1, 0]], dtype=np.float32) + sheared_mask = cv2.warpAffine( + mask_max, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + sheared_img = cv2.warpAffine( + img_orig, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + union_mask = np.max([mask_max, sheared_mask], axis=0) + img = (1.0 - union_mask) * sheared_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgShearY(ShearY): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + """Shear the image vertically.""" + mask_max = np.max(results['masks'], axis=0) + shear_matrix = np.array([[1, 0, 0], [mag, 1, -mag*w_img//2]], dtype=np.float32) + sheared_mask = cv2.warpAffine( + mask_max, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + sheared_img = cv2.warpAffine( + img_orig, + shear_matrix, + (w_img, h_img), + borderValue=tuple([0] * 3), + flags=cv2.INTER_LINEAR) + + union_mask = np.max([mask_max, sheared_mask], axis=0) + img = (1.0 - union_mask) * sheared_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgRotate(Rotate): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + + """Rotate the image.""" + mask_max = np.max(results['masks'], axis=0) + rotated_mask = mmcv.imrotate( + mask_max, + mag, + border_value=0, + interpolation=self.interpolation) + rotated_img = mmcv.imrotate( + img_orig, + mag, + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, rotated_mask], axis=0) + img = (1.0 - union_mask) * rotated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgTranslateX(TranslateX): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + """Translate the image horizontally.""" + _mag = w_img * mag + mask_max = np.max(results['masks'], axis=0) + translated_mask = mmcv.imtranslate( + mask_max, + _mag, + direction='horizontal', + border_value=0, + interpolation=self.interpolation) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='horizontal', + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, translated_mask], axis=0) + img = (1.0 - union_mask) * translated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass + + +@TRANSFORMS.register_module() +class BgTranslateY(TranslateY): + def __init__(self, **kwargs): + super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: + img_orig = results['img'].copy() + (h_img, w_img, c_img) = img_orig.shape + + """Translate the image horizontally.""" + _mag = h_img * mag + mask_max = np.max(results['masks'], axis=0) + translated_mask = mmcv.imtranslate( + mask_max, + _mag, + direction='vertical', + border_value=0, + interpolation=self.interpolation) + translated_img = mmcv.imtranslate( + img_orig, + _mag, + direction='vertical', + border_value=self.img_border_value, + interpolation=self.interpolation) + + union_mask = np.max([mask_max, translated_mask], axis=0) + img = (1.0 - union_mask) * translated_img + union_mask * img_orig + + results['img'] = img + + def _transform_masks(self, results: dict, mag: float) -> None: + pass + + def _transform_seg(self, results: dict, mag: float) -> None: + pass diff --git a/mmdet/datasets/transforms/oa_mix.py b/mmdet/datasets/transforms/oa_mix.py new file mode 100644 index 00000000000..2da1c7f590d --- /dev/null +++ b/mmdet/datasets/transforms/oa_mix.py @@ -0,0 +1,385 @@ +import cv2 +import numpy as np +from typing import List, Tuple + +from mmcv.transforms import BaseTransform, Compose + +from mmdet.registry import TRANSFORMS + + +def get_transforms(version: str) -> List[Compose]: + if version == 'color': + transforms = [ + dict(type='AutoContrast'), dict(type='Brightness'), dict(type='Color'), + dict(type='Contrast'), dict(type='Equalize'), dict(type='Invert4Mix'), + dict(type='Posterize'), dict(type='Sharpness'), dict(type='Solarize'), + dict(type='SolarizeAdd') + ] + elif version == 'geo': + transforms = [ + dict(type='BgShearX'), dict(type='BgShearY'), dict(type='BgRotate'), + dict(type='BgTranslateX'), dict(type='BgTranslateY'), + dict(type='BBoxShearX'), dict(type='BBoxShearY'), dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), dict(type='BBoxTranslateY'), + ] + elif version == 'oamix': + transforms = [ + dict(type='AutoContrast'), dict(type='Brightness'), dict(type='Color'), + dict(type='Contrast'), dict(type='Equalize'), dict(type='Invert4Mix'), + dict(type='Posterize'), dict(type='Sharpness'), + dict(type='BgShearX'), dict(type='BgShearY'), dict(type='BgRotate'), + dict(type='BgTranslateX'), dict(type='BgTranslateY'), + dict(type='BBoxShearX'), dict(type='BBoxShearY'), dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), dict(type='BBoxTranslateY'), + ] + else: + raise TypeError(f"Invalid version: {version}. Please add the version to the get_transforms function.") + transforms = [Compose(transforms) for transforms in transforms] + return transforms + + +def bbox_overlaps_np(bboxes1: np.ndarray, bboxes2: np.ndarray, eps: float = 1e-6) -> np.ndarray: + """Calculate overlap between two set of bboxes. + + Args: + bboxes1 (ndarray): shape (B, m, 4) in format or empty. + bboxes2 (ndarray): shape (B, n, 4) in format or empty. + B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. + Returns: + ndarray: shape (m, n) if ``is_aligned`` is False else shape (m,) + """ + if len(bboxes2) == 0: + return np.zeros((bboxes1.shape[-2], 0)) + assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0) + assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0) + + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.shape[-2] + cols = bboxes2.shape[-2] + + if rows * cols == 0: + return np.zeros(batch_shape + (rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1]) + + lt = np.maximum(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = np.minimum(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = np.clip(rb - lt, a_min=0, a_max=None) + overlap = wh[..., 0] * wh[..., 1] + + union = area1[..., None] + area2[..., None, :] - overlap + + union = np.maximum(union, eps) + ious = overlap / union + return ious + + +@TRANSFORMS.register_module() +class OAMix(BaseTransform): + r"""Data augmentation method in `Object-Aware Domain Generalization for Object Detection + `_. + + Refer to https://github.com/woojulee24/OA-DG for implementation details. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) + + Modified Keys: + + - img + - gt_bboxes + + Args: + version (str): The version of the augmentation method. + Defaults to 'oamix'. + aug_prob_coeff (float): The coefficient of the augmentation probability. + Defaults to 1.0. + mixture_width (int): The number of augmentation operations in the mixture. + Defaults to 3. + mixture_depth (int): The depth of augmentation operations in the mixture. + If mixture_depth is -1, the depth is randomly sampled from [1, 4]. + Defaults to -1. + box_scale (tuple): The scale of the random bounding boxes. + Defaults to (0.01, 0.1). + box_ratio (tuple): The aspect ratio of the random bounding boxes. + Defaults to (3, 1/3). + sigma_ratio (float): The ratio of the sigma for the Gaussian blur. + Defaults to 0.3. + score_thresh (float): The threshold of the saliency score. + Defaults to 10. + """ + def __init__(self, + version: str = "oamix", + aug_prob_coeff: float = 1.0, + mixture_width: int = 3, + mixture_depth: int = -1, + box_scale: tuple = (0.01, 0.1), + box_ratio: tuple = (3, 0.33), + sigma_ratio: float = 0.3, + score_thresh: float = 10.0) -> None: + assert version in ['color', 'geo', 'oamix'], "The version should be either 'color', 'geo', or 'oamix'." \ + "Please add the version to the get_transforms function." + assert aug_prob_coeff > 0, "The augmentation probability coefficient should be greater than 0." + assert isinstance(mixture_width, int) and mixture_width > 0, "The mixture width should be greater than 0." + assert isinstance(mixture_depth, int) and mixture_depth >= -1, "The mixture depth should be greater than or equal to -1." + assert isinstance(box_scale, tuple) and len(box_scale) == 2, "The box scale should be a tuple of 2 elements." + assert isinstance(box_ratio, tuple) and len(box_ratio) == 2, "The box ratio should be a tuple of 2 elements." + assert 0 <= sigma_ratio <= 1, "The sigma ratio should be in the range [0, 1]." + assert score_thresh >= 0, "The score threshold should be greater than or equal to 0." + super(OAMix, self).__init__() + + self.version = version + self.transforms = get_transforms(version) + self.aug_prob_coeff = aug_prob_coeff + self.mixture_width = mixture_width + self.mixture_depth = mixture_depth + self.box_scale = box_scale + self.box_ratio = box_ratio + self.sigma_ratio = sigma_ratio + self.score_thresh = score_thresh + + def transform(self, results) -> dict: + """ The transform function. """ + img = results['img'] + gt_bboxes = results['gt_bboxes'].numpy() + gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) + + results['img'] = self.oamix(img.copy(), gt_bboxes, gt_masks) + + return results + + def oamix(self, img_orig: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]) -> np.ndarray: + ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """ Multi-level transformation """ + depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_mix += ws[i] * img_aug + + """ Object-aware mixing """ + img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) + + return np.asarray(img_oamix, dtype=img_orig.dtype) + + def multilivel_transform(self, img: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]) -> np.ndarray: + rand_bboxes = self.get_random_bboxes(img.shape, num_bboxes=(1, 3)) + rand_masks = self.get_masks(rand_bboxes, img.shape) + + img_tmp = np.zeros_like(img, dtype=np.float32) + for rand_mask in rand_masks: + img_tmp += rand_mask * self.aug(img, gt_bboxes, gt_masks) + union_mask = np.max(rand_masks, axis=0) + img_aug = np.asarray( + img_tmp + (1.0 - union_mask) * self.aug(img, gt_bboxes, gt_masks), dtype=img.dtype + ) + return img_aug + + def get_random_bboxes(self, img_shape: Tuple, num_bboxes: Tuple[int, int], + max_iters: int = 50, eps: float = 1e-6) -> np.ndarray: + assert max_iters > 0, "The maximum number of iterations should be greater than 0." + + h_img, w_img, _ = img_shape + num_target_bboxes = np.random.randint(*num_bboxes) + rand_bboxes = np.zeros((0, 4)) + for i in range(max_iters): + if len(rand_bboxes) >= num_target_bboxes: + break + scale = np.random.uniform(*self.box_scale) + aspect_ratio = np.random.uniform(*self.box_ratio) + + height = scale * h_img + width = height * aspect_ratio + if width > w_img or height > h_img: + continue # Invalid bbox (out of the image) + + xmin = np.random.uniform(0, w_img - width) + ymin = np.random.uniform(0, h_img - height) + xmax = xmin + width + ymax = ymin + height + + rand_bbox = np.array([[xmin, ymin, xmax, ymax]]) + ious = bbox_overlaps_np(rand_bbox, rand_bboxes) + if np.sum(ious) > eps: + continue # Invalid bbox (overlapping with existing bboxes) + + rand_bboxes = np.concatenate([rand_bboxes, rand_bbox], axis=0) + + return rand_bboxes + + def get_masks(self, bboxes: np.ndarray, img_shape: tuple, use_blur: bool = False) -> List[np.ndarray]: + """ Get the masks of the bounding boxes. """ + mask_list = [] + for bbox in bboxes: + if len(bbox.shape) == 2 and bbox.shape[1] == 4: + bbox = bbox[0] + mask = np.zeros(img_shape, dtype=np.float32) + mask[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] = 1.0 + if use_blur: + sigma_x = (bbox[2] - bbox[0]) * self.sigma_ratio / 3 * 2 + sigma_y = (bbox[3] - bbox[1]) * self.sigma_ratio / 3 * 2 + if not (sigma_x <= 0 or sigma_y <= 0): + mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=sigma_x.item(), sigmaY=sigma_y.item()) + mask = cv2.resize(mask, (img_shape[1], img_shape[0]), interpolation=cv2.INTER_LINEAR) + mask_list.append(mask) + return mask_list + + def aug(self, img: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]): + op = np.random.choice(self.transforms) + op_kwargs = {'img': img, 'bboxes': gt_bboxes, 'masks': gt_masks, 'img_shape': img.shape} + img_aug = op(op_kwargs)['img'] + return img_aug + + def object_aware_mixing(self, + img_orig: np.ndarray, + img_mix: np.ndarray, + gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: + gt_scores = self.get_saliency_scores(img_orig, gt_bboxes) + + target_indices = gt_scores < self.score_thresh + target_bboxes = gt_bboxes[target_indices] + target_masks = [gt_masks[i] for i in np.where(target_indices)[0]] + target_m = np.random.uniform(0.0, 0.5, len(target_bboxes)).astype(np.float32) + + rand_bboxes = self.get_random_bboxes(img_orig.shape, num_bboxes=(3, 5)) + rand_masks = self.get_masks(rand_bboxes, img_orig.shape, use_blur=True) + rand_m = np.random.uniform(0.0, 1.0, len(rand_bboxes)).astype(np.float32) + + target_bboxes = np.vstack((target_bboxes, rand_bboxes)) + target_masks.extend(rand_masks) + target_m = np.concatenate((target_m, rand_m)) + + orig = np.zeros_like(img_orig, dtype=np.float32) + aug = np.zeros_like(img_orig, dtype=np.float32) + mask_sum = np.zeros_like(img_orig, dtype=np.float32) + + for bbox, mask, m in zip(target_bboxes, target_masks, target_m): + mask_sum += mask + mask_max = np.maximum(mask_sum, mask) + mask_overlap = mask_sum - mask_max + overlap_factor = (mask - mask_overlap * 0.5) + + orig += (1.0 - m) * img_orig * overlap_factor + aug += m * img_mix * overlap_factor + mask_sum = mask_max + + img_oamix = orig + aug + + m = np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff) + + img_oamix += (1.0 - m) * img_orig * (1.0 - mask_sum) + img_oamix += m * img_mix * (1.0 - mask_sum) + img_oamix = np.clip(img_oamix, 0, 255) + + return img_oamix + + @staticmethod + def get_saliency_scores(img: np.ndarray, bboxes: np.ndarray) -> np.ndarray: + saliency_scores = [] + for bbox in np.asarray(bboxes, dtype=np.int32): + bbox_img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]] + + saliency = cv2.saliency.StaticSaliencySpectralResidual_create() + (success, saliency_map) = saliency.computeSaliency(bbox_img) + score = np.mean((saliency_map * 255).astype("uint8")) + saliency_scores.append(score) + return np.asarray(saliency_scores, dtype=np.uint8) + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f'(version={self.version}, aug_prob_coeff={self.aug_prob_coeff}, ' \ + f'mixture_width={self.mixture_width}, mixture_depth={self.mixture_depth}, ' \ + f'box_scale={self.box_scale}, box_ratio={self.box_ratio}, ' \ + f'sigma_ratio={self.sigma_ratio}, score_thresh={self.score_thresh})' + return repr_str + + @staticmethod + def _load_example_data() -> Tuple[np.ndarray, np.ndarray]: + img = cv2.imread("../demo/demo.jpg") + gt_bboxes = np.array( + [[609.2460327148438, 111.9759292602539, 635.9223022460938, 137.42437744140625], + [480.33782958984375, 110.44952392578125, 521.1831665039062, 129.6164093017578], + [295.34356689453125, 116.82196807861328, 379.7244873046875, 149.78955078125], + [219.83975219726562, 177.6780548095703, 455.0238952636719, 382.3981628417969], + [0.2255704551935196, 111.30818176269531, 62.27484130859375, 145.1354522705078], + [191.24462890625, 108.73335266113281, 297.60186767578125, 155.57919311523438], + [431.61810302734375, 105.31916046142578, 482.30120849609375, 132.21238708496094], + [589.4951171875, 111.1348648071289, 616.8546752929688, 126.33065032958984], + [167.97503662109375, 106.92251586914062, 211.0978546142578, 140.34495544433594], + [270.0672912597656, 104.84465789794922, 326.28662109375, 128.14691162109375], + [395.87152099609375, 111.33557891845703, 433.2410583496094, 132.824462890625], + [60.50261306762695, 94.38719940185547, 85.39842224121094, 105.71919250488281], + [373.8731384277344, 136.39341735839844, 434.0091857910156, 187.2471466064453], + [141.07984924316406, 96.27764129638672, 166.647705078125, 105.06587982177734], + [224.4158477783203, 97.63524627685547, 249.95281982421875, 107.63406372070312], + [556.1692504882812, 110.58447265625, 588.9140014648438, 127.4863052368164], + [77.04759979248047, 90.36402130126953, 97.74053192138672, 98.54667663574219]] + ) + return img, gt_bboxes + + def _test_transformations(self) -> None: + img, gt_bboxes = self._load_example_data() + gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) + + transform_type_list = ['AutoContrast', 'Brightness', 'Color', 'Contrast', 'Equalize', 'Invert4Mix', + 'Posterize', 'Sharpness', 'Solarize', 'SolarizeAdd', + 'BgShearX', 'BgShearY', 'BgRotate', 'BgTranslateX', 'BgTranslateY', + 'BBoxShearX', 'BBoxShearY', 'BBoxRotate', 'BBoxTranslateX', 'BBoxTranslateY'] + for type in transform_type_list: + op = Compose(dict(type=type)) + op_kwargs = {'img': img, 'bboxes': gt_bboxes, 'masks': gt_masks, 'img_shape': img.shape} + _ = op(op_kwargs)['img'] + + return + + def _test_multilevel_transformations(self) -> None: + img_orig, gt_bboxes = self._load_example_data() + gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) + + for idx in range(10): + ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """ Multi-level transformation """ + depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_mix += ws[i] * img_aug + + return + + def _test_objectaware_mixing(self) -> None: + img_orig, gt_bboxes = self._load_example_data() + gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) + + ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + img_mix = np.zeros_like(img_orig, dtype=np.float32) + for i in range(self.mixture_width): + """ Multi-level transformation """ + depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + img_aug = img_orig.copy() + for _ in range(depth): + img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_mix += ws[i] * img_aug + + """ Object-aware mixing """ + img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) + + return From 471532c9642ff3f6ca82978770fb0aa6ebcc19c6 Mon Sep 17 00:00:00 2001 From: dazory Date: Tue, 20 Aug 2024 13:20:52 +0000 Subject: [PATCH 2/7] add README and oamix --- README.md | 460 +++------------------------- mmdet/datasets/transforms/oa_mix.py | 275 ++++++++++------- resources/oamix_examples.gif | Bin 0 -> 5278753 bytes resources/oamix_examples.png | Bin 0 -> 262113 bytes 4 files changed, 209 insertions(+), 526 deletions(-) mode change 100644 => 100755 mmdet/datasets/transforms/oa_mix.py create mode 100644 resources/oamix_examples.gif create mode 100644 resources/oamix_examples.png diff --git a/README.md b/README.md index 34f7f0b8f90..420006ba8a3 100644 --- a/README.md +++ b/README.md @@ -1,455 +1,69 @@ -
- -
 
-
- OpenMMLab website - - - HOT - - -      - OpenMMLab platform - - - TRY IT OUT - - -
-
 
+# MMDetection with OA-Mix -[![PyPI](https://img.shields.io/pypi/v/mmdet)](https://pypi.org/project/mmdet) -[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmdetection.readthedocs.io/en/latest/) -[![badge](https://github.com/open-mmlab/mmdetection/workflows/build/badge.svg)](https://github.com/open-mmlab/mmdetection/actions) -[![codecov](https://codecov.io/gh/open-mmlab/mmdetection/branch/main/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmdetection) -[![license](https://img.shields.io/github/license/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/blob/main/LICENSE) -[![open issues](https://isitmaintained.com/badge/open/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/issues) -[![issue resolution](https://isitmaintained.com/badge/resolution/open-mmlab/mmdetection.svg)](https://github.com/open-mmlab/mmdetection/issues) -[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_demo.svg)](https://openxlab.org.cn/apps?search=mmdet) - -[📘Documentation](https://mmdetection.readthedocs.io/en/latest/) | -[🛠️Installation](https://mmdetection.readthedocs.io/en/latest/get_started.html) | -[👀Model Zoo](https://mmdetection.readthedocs.io/en/latest/model_zoo.html) | -[🆕Update News](https://mmdetection.readthedocs.io/en/latest/notes/changelog.html) | -[🚀Ongoing Projects](https://github.com/open-mmlab/mmdetection/projects) | -[🤔Reporting Issues](https://github.com/open-mmlab/mmdetection/issues/new/choose) - -
- -
- -English | [简体中文](README_zh-CN.md) - -
- -
- - - - - - - - - - - - - - - - - -
- -
- +
+
## Introduction -MMDetection is an open source object detection toolbox based on PyTorch. It is -a part of the [OpenMMLab](https://openmmlab.com/) project. - -The main branch works with **PyTorch 1.8+**. - - - -
-Major features +This repository is a fork of the [mmdetection](https://github.com/open-mmlab/mmdetection) toolbox with the implementation of OA-Mix, +a novel data augmentation technique designed to improve domain generalization in single-domain object detection. +OA-Mix is part of the [Object-Aware Domain Generalization (OA-DG)](https://github.com/woojulee24/OA-DG) framework, +introduced in the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076). -- **Modular Design** +This repository has been created to showcase the OA-Mix method. +The method enhances model robustness against domain shifts by generating diverse multi-domain data while preserving object annotations. - We decompose the detection framework into different components and one can easily construct a customized object detection framework by combining different modules. +For more information on the details of OA-Mix and its use cases, +please refer to the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076), presented at AAAI 2024. -- **Support of multiple tasks out of box** +## Example of OA-Mix - The toolbox directly supports multiple detection tasks such as **object detection**, **instance segmentation**, **panoptic segmentation**, and **semi-supervised object detection**. - -- **High efficiency** - - All basic bbox and mask operations run on GPUs. The training speed is faster than or comparable to other codebases, including [Detectron2](https://github.com/facebookresearch/detectron2), [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark) and [SimpleDet](https://github.com/TuSimple/simpledet). - -- **State of the art** - - The toolbox stems from the codebase developed by the *MMDet* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward. - The newly released [RTMDet](configs/rtmdet) also obtains new state-of-the-art results on real-time instance segmentation and rotated object detection tasks and the best parameter-accuracy trade-off on object detection. - -
- -Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmlab/mmengine) for model training and [MMCV](https://github.com/open-mmlab/mmcv) for computer vision research, which are heavily depended on by this toolbox. - -## What's New - -💎 **We have released the pre-trained weights for MM-Grounding-DINO Swin-B and Swin-L, welcome to try and give feedback.** - -### Highlight - -**v3.3.0** was released in 5/1/2024: - -**[MM-Grounding-DINO: An Open and Comprehensive Pipeline for Unified Object Grounding and Detection](https://arxiv.org/abs/2401.02361)** - -Grounding DINO is a grounding pre-training model that unifies 2d open vocabulary object detection and phrase grounding, with wide applications. However, its training part has not been open sourced. Therefore, we propose MM-Grounding-DINO, which not only serves as an open source replication version of Grounding DINO, but also achieves significant performance improvement based on reconstructed data types, exploring different dataset combinations and initialization strategies. Moreover, we conduct evaluations from multiple dimensions, including OOD, REC, Phrase Grounding, OVD, and Fine-tune, to fully excavate the advantages and disadvantages of Grounding pre-training, hoping to provide inspiration for future work. - -code: [mm_grounding_dino/README.md](configs/mm_grounding_dino/README.md) +Below is an example showing the results of OA-Mix:
- +
-We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes but also obtains new state-of-the-art performance on instance segmentation and rotated object detection tasks. Details can be found in the [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models are [here](configs/rtmdet). +## Performance Improvement with OA-Mix -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) -[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) +Below is a performance comparison between a baseline object detection model and the same model with OA-Mix applied: -| Task | Dataset | AP | FPS(TRT FP16 BS1 3090) | -| ------------------------ | ------- | ------------------------------------ | ---------------------- | -| Object Detection | COCO | 52.8 | 322 | -| Instance Segmentation | COCO | 44.6 | 188 | -| Rotated Object Detection | DOTA | 78.9(single-scale)/81.3(multi-scale) | 121 | +| Model | Dataset | mAP | Gauss. | Shot | Impulse | Defocus | Glass | Motion | Zoom | Snow | Frost | Fog | Bright | Contrast | Elastic | Pixel | JPEG | mPC | +| :-------------------: | :----------: | :--: | :----: | :--: | :-----: | :-----: | :---: | :----: | :--: | :--: | :---: | :--: | :----: | :------: | :-----: | ----- | :--: | :--: | +| Faster R-CNN | Cityscapes-C | 42.2 | 0.5 | 1.1 | 1.1 | 17.2 | 16.5 | 18.3 | 2.1 | 2.2 | 12.3 | 29.8 | 32.0 | 24.1 | 40.1 | 18.7 | 15.1 | 15.4 | +| Faster R-CNN + OA-Mix | Cityscapes-C | 42.7 | 7.2 | 9.6 | 7.7 | 22.8 | 18.8 | 21.9 | 5.4 | 5.2 | 23.6 | 37.3 | 38.7 | 31.9 | 40.2 | 22.2 | 20.2 | 20.8 | -
- -
- -## Installation - -Please refer to [Installation](https://mmdetection.readthedocs.io/en/latest/get_started.html) for installation instructions. - -## Getting Started - -Please see [Overview](https://mmdetection.readthedocs.io/en/latest/get_started.html) for the general introduction of MMDetection. - -For detailed user guides and advanced guides, please refer to our [documentation](https://mmdetection.readthedocs.io/en/latest/): - -- User Guides - -
- - - [Train & Test](https://mmdetection.readthedocs.io/en/latest/user_guides/index.html#train-test) - - [Learn about Configs](https://mmdetection.readthedocs.io/en/latest/user_guides/config.html) - - [Inference with existing models](https://mmdetection.readthedocs.io/en/latest/user_guides/inference.html) - - [Dataset Prepare](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html) - - [Test existing models on standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/test.html) - - [Train predefined models on standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/train.html) - - [Train with customized datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/train.html#train-with-customized-datasets) - - [Train with customized models and standard datasets](https://mmdetection.readthedocs.io/en/latest/user_guides/new_model.html) - - [Finetuning Models](https://mmdetection.readthedocs.io/en/latest/user_guides/finetune.html) - - [Test Results Submission](https://mmdetection.readthedocs.io/en/latest/user_guides/test_results_submission.html) - - [Weight initialization](https://mmdetection.readthedocs.io/en/latest/user_guides/init_cfg.html) - - [Use a single stage detector as RPN](https://mmdetection.readthedocs.io/en/latest/user_guides/single_stage_as_rpn.html) - - [Semi-supervised Object Detection](https://mmdetection.readthedocs.io/en/latest/user_guides/semi_det.html) - - [Useful Tools](https://mmdetection.readthedocs.io/en/latest/user_guides/index.html#useful-tools) - -
- -- Advanced Guides - -
- - - [Basic Concepts](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#basic-concepts) - - [Component Customization](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#component-customization) - - [How to](https://mmdetection.readthedocs.io/en/latest/advanced_guides/index.html#how-to) +## mmdetection Readme -
+For information on mmdetection please refer to the [mmdetection readme](MMDETECTION_README.md). -We also provide object detection colab tutorial [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](demo/MMDet_Tutorial.ipynb) and instance segmentation colab tutorial [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](demo/MMDet_InstanceSeg_Tutorial.ipynb). - -To migrate from MMDetection 2.x, please refer to [migration](https://mmdetection.readthedocs.io/en/latest/migration.html). - -## Overview of Benchmark and Model Zoo - -Results and models are available in the [model zoo](docs/en/model_zoo.md). - -
- Architectures -
- - - - - - - - - - - - - - - - - -
- Object Detection - - Instance Segmentation - - Panoptic Segmentation - - Other -
- - - - - - - -
  • Contrastive Learning
  • - - -
  • Distillation
  • - -
  • Semi-Supervised Object Detection
  • - - -
    - -
    - Components -
    - - - - - - - - - - - - - - - - - -
    - Backbones - - Necks - - Loss - - Common -
    - - - - - - - -
    - -Some other methods are also supported in [projects using MMDetection](./docs/en/notes/projects.md). - -## FAQ - -Please refer to [FAQ](docs/en/notes/faq.md) for frequently asked questions. - -## Contributing +## Installation -We appreciate all contributions to improve MMDetection. Ongoing projects can be found in out [GitHub Projects](https://github.com/open-mmlab/mmdetection/projects). Welcome community users to participate in these projects. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline. +Please refer to [INSTALL.md](INSTALL.md) for installation and dataset preparation. -## Acknowledgement +## Get Started -MMDetection is an open source project that is contributed by researchers and engineers from various colleges and companies. We appreciate all the contributors who implement their methods or add new features, as well as users who give valuable feedbacks. -We wish that the toolbox and benchmark could serve the growing research community by providing a flexible toolkit to reimplement existing methods and develop their own new detectors. +Please see [GETTING_STARTED.md](GETTING_STARTED.md) for the basic usage of MMDetection. ## Citation If you use this toolbox or benchmark in your research, please cite this project. ``` -@article{mmdetection, - title = {{MMDetection}: Open MMLab Detection Toolbox and Benchmark}, - author = {Chen, Kai and Wang, Jiaqi and Pang, Jiangmiao and Cao, Yuhang and - Xiong, Yu and Li, Xiaoxiao and Sun, Shuyang and Feng, Wansen and - Liu, Ziwei and Xu, Jiarui and Zhang, Zheng and Cheng, Dazhi and - Zhu, Chenchen and Cheng, Tianheng and Zhao, Qijie and Li, Buyu and - Lu, Xin and Zhu, Rui and Wu, Yue and Dai, Jifeng and Wang, Jingdong - and Shi, Jianping and Ouyang, Wanli and Loy, Chen Change and Lin, Dahua}, - journal= {arXiv preprint arXiv:1906.07155}, - year={2019} +@inproceedings{lee2024object, + title={Object-Aware Domain Generalization for Object Detection}, + author={Lee, Wooju and Hong, Dasol and Lim, Hyungtae and Myung, Hyun}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={38}, + number={4}, + pages={2947--2955}, + year={2024} } ``` -## License - -This project is released under the [Apache 2.0 license](LICENSE). +## Contact -## Projects in OpenMMLab +This repo is currently maintained by Wooju Lee ([@WoojuLee24](https://github.com/WoojuLee24)) and Dasol Hong ([@dazory](https://github.com/dazory)). -- [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models. -- [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision. -- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark. -- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox. -- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark. -- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection. -- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark. -- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark. -- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark. -- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox. -- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark. -- [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark. -- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark. -- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark. -- [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark. -- [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark. -- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark. -- [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark. -- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox. -- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox. -- [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab model deployment framework. -- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages. -- [MMEval](https://github.com/open-mmlab/mmeval): A unified evaluation library for multiple machine learning libraries. -- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab. +For questions regarding mmdetection please visit the [official repository](https://github.com/open-mmlab/mmdetection). diff --git a/mmdet/datasets/transforms/oa_mix.py b/mmdet/datasets/transforms/oa_mix.py old mode 100644 new mode 100755 index 2da1c7f590d..f03cc6b00c7 --- a/mmdet/datasets/transforms/oa_mix.py +++ b/mmdet/datasets/transforms/oa_mix.py @@ -1,7 +1,8 @@ -import cv2 -import numpy as np +# Copyright (c) OpenMMLab. All rights reserved. from typing import List, Tuple +import cv2 +import numpy as np from mmcv.transforms import BaseTransform, Compose from mmdet.registry import TRANSFORMS @@ -10,35 +11,62 @@ def get_transforms(version: str) -> List[Compose]: if version == 'color': transforms = [ - dict(type='AutoContrast'), dict(type='Brightness'), dict(type='Color'), - dict(type='Contrast'), dict(type='Equalize'), dict(type='Invert4Mix'), - dict(type='Posterize'), dict(type='Sharpness'), dict(type='Solarize'), + dict(type='AutoContrast'), + dict(type='Brightness'), + dict(type='Color'), + dict(type='Contrast'), + dict(type='Equalize'), + dict(type='Invert4Mix'), + dict(type='Posterize'), + dict(type='Sharpness'), + dict(type='Solarize'), dict(type='SolarizeAdd') ] elif version == 'geo': transforms = [ - dict(type='BgShearX'), dict(type='BgShearY'), dict(type='BgRotate'), - dict(type='BgTranslateX'), dict(type='BgTranslateY'), - dict(type='BBoxShearX'), dict(type='BBoxShearY'), dict(type='BBoxRotate'), - dict(type='BBoxTranslateX'), dict(type='BBoxTranslateY'), + dict(type='BgShearX'), + dict(type='BgShearY'), + dict(type='BgRotate'), + dict(type='BgTranslateX'), + dict(type='BgTranslateY'), + dict(type='BBoxShearX'), + dict(type='BBoxShearY'), + dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), + dict(type='BBoxTranslateY'), ] elif version == 'oamix': transforms = [ - dict(type='AutoContrast'), dict(type='Brightness'), dict(type='Color'), - dict(type='Contrast'), dict(type='Equalize'), dict(type='Invert4Mix'), - dict(type='Posterize'), dict(type='Sharpness'), - dict(type='BgShearX'), dict(type='BgShearY'), dict(type='BgRotate'), - dict(type='BgTranslateX'), dict(type='BgTranslateY'), - dict(type='BBoxShearX'), dict(type='BBoxShearY'), dict(type='BBoxRotate'), - dict(type='BBoxTranslateX'), dict(type='BBoxTranslateY'), + dict(type='AutoContrast'), + dict(type='Brightness'), + dict(type='Color'), + dict(type='Contrast'), + dict(type='Equalize'), + dict(type='Invert4Mix'), + dict(type='Posterize'), + dict(type='Sharpness'), + dict(type='BgShearX'), + dict(type='BgShearY'), + dict(type='BgRotate'), + dict(type='BgTranslateX'), + dict(type='BgTranslateY'), + dict(type='BBoxShearX'), + dict(type='BBoxShearY'), + dict(type='BBoxRotate'), + dict(type='BBoxTranslateX'), + dict(type='BBoxTranslateY'), ] else: - raise TypeError(f"Invalid version: {version}. Please add the version to the get_transforms function.") + raise TypeError( + f'Invalid version: {version}. ' + f'Please add the version to the get_transforms function.') transforms = [Compose(transforms) for transforms in transforms] return transforms -def bbox_overlaps_np(bboxes1: np.ndarray, bboxes2: np.ndarray, eps: float = 1e-6) -> np.ndarray: +def bbox_overlaps_np(bboxes1: np.ndarray, + bboxes2: np.ndarray, + eps: float = 1e-6) -> np.ndarray: """Calculate overlap between two set of bboxes. Args: @@ -87,7 +115,8 @@ def bbox_overlaps_np(bboxes1: np.ndarray, bboxes2: np.ndarray, eps: float = 1e-6 @TRANSFORMS.register_module() class OAMix(BaseTransform): - r"""Data augmentation method in `Object-Aware Domain Generalization for Object Detection + r"""Data augmentation method in + `Object-Aware Domain Generalization for Object Detection `_. Refer to https://github.com/woojulee24/OA-DG for implementation details. @@ -105,11 +134,14 @@ class OAMix(BaseTransform): Args: version (str): The version of the augmentation method. Defaults to 'oamix'. - aug_prob_coeff (float): The coefficient of the augmentation probability. + aug_prob_coeff (float): + The coefficient of the augmentation probability. Defaults to 1.0. - mixture_width (int): The number of augmentation operations in the mixture. + mixture_width (int): + The number of augmentation operations in the mixture. Defaults to 3. - mixture_depth (int): The depth of augmentation operations in the mixture. + mixture_depth (int): + The depth of augmentation operations in the mixture. If mixture_depth is -1, the depth is randomly sampled from [1, 4]. Defaults to -1. box_scale (tuple): The scale of the random bounding boxes. @@ -121,8 +153,9 @@ class OAMix(BaseTransform): score_thresh (float): The threshold of the saliency score. Defaults to 10. """ + def __init__(self, - version: str = "oamix", + version: str = 'oamix', aug_prob_coeff: float = 1.0, mixture_width: int = 3, mixture_depth: int = -1, @@ -130,15 +163,27 @@ def __init__(self, box_ratio: tuple = (3, 0.33), sigma_ratio: float = 0.3, score_thresh: float = 10.0) -> None: - assert version in ['color', 'geo', 'oamix'], "The version should be either 'color', 'geo', or 'oamix'." \ - "Please add the version to the get_transforms function." - assert aug_prob_coeff > 0, "The augmentation probability coefficient should be greater than 0." - assert isinstance(mixture_width, int) and mixture_width > 0, "The mixture width should be greater than 0." - assert isinstance(mixture_depth, int) and mixture_depth >= -1, "The mixture depth should be greater than or equal to -1." - assert isinstance(box_scale, tuple) and len(box_scale) == 2, "The box scale should be a tuple of 2 elements." - assert isinstance(box_ratio, tuple) and len(box_ratio) == 2, "The box ratio should be a tuple of 2 elements." - assert 0 <= sigma_ratio <= 1, "The sigma ratio should be in the range [0, 1]." - assert score_thresh >= 0, "The score threshold should be greater than or equal to 0." + assert version in ['color', 'geo', 'oamix'], \ + "The version should be either 'color', 'geo', or 'oamix'."\ + 'Please add the version to the get_transforms function.' + assert aug_prob_coeff > 0, \ + 'The augmentation probability coefficient ' \ + 'should be greater than 0.' + assert isinstance( + mixture_width, int + ) and mixture_width > 0, 'The mixture width should be greater than 0.' + assert isinstance( + mixture_depth, int + ) and mixture_depth >= -1, \ + 'The mixture depth should be greater than or equal to -1.' + assert isinstance(box_scale, tuple) and len( + box_scale) == 2, 'The box scale should be a tuple of 2 elements.' + assert isinstance(box_ratio, tuple) and len( + box_ratio) == 2, 'The box ratio should be a tuple of 2 elements.' + assert 0 <= sigma_ratio <= 1, \ + 'The sigma ratio should be in the range [0, 1].' + assert score_thresh >= 0, \ + 'The score threshold should be greater than or equal to 0.' super(OAMix, self).__init__() self.version = version @@ -152,7 +197,8 @@ def __init__(self, self.score_thresh = score_thresh def transform(self, results) -> dict: - """ The transform function. """ + self._test_transformations() + """The transform function.""" img = results['img'] gt_bboxes = results['gt_bboxes'].numpy() gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) @@ -161,23 +207,28 @@ def transform(self, results) -> dict: return results - def oamix(self, img_orig: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]) -> np.ndarray: - ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + def oamix(self, img_orig: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) img_mix = np.zeros_like(img_orig, dtype=np.float32) for i in range(self.mixture_width): - """ Multi-level transformation """ - depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + """Multi-level transformation.""" + depth = self.mixture_depth \ + if self.mixture_depth > 0 else np.random.randint(1, 4) img_aug = img_orig.copy() for _ in range(depth): - img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_aug = self.multilivel_transform(img_aug, gt_bboxes, + gt_masks) img_mix += ws[i] * img_aug - """ Object-aware mixing """ - img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) + img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, + gt_masks) return np.asarray(img_oamix, dtype=img_orig.dtype) - def multilivel_transform(self, img: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]) -> np.ndarray: + def multilivel_transform(self, img: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]) -> np.ndarray: rand_bboxes = self.get_random_bboxes(img.shape, num_bboxes=(1, 3)) rand_masks = self.get_masks(rand_bboxes, img.shape) @@ -186,13 +237,17 @@ def multilivel_transform(self, img: np.ndarray, gt_bboxes: np.ndarray, gt_masks: img_tmp += rand_mask * self.aug(img, gt_bboxes, gt_masks) union_mask = np.max(rand_masks, axis=0) img_aug = np.asarray( - img_tmp + (1.0 - union_mask) * self.aug(img, gt_bboxes, gt_masks), dtype=img.dtype - ) + img_tmp + (1.0 - union_mask) * self.aug(img, gt_bboxes, gt_masks), + dtype=img.dtype) return img_aug - def get_random_bboxes(self, img_shape: Tuple, num_bboxes: Tuple[int, int], - max_iters: int = 50, eps: float = 1e-6) -> np.ndarray: - assert max_iters > 0, "The maximum number of iterations should be greater than 0." + def get_random_bboxes(self, + img_shape: Tuple, + num_bboxes: Tuple[int, int], + max_iters: int = 50, + eps: float = 1e-6) -> np.ndarray: + assert max_iters > 0, \ + 'The maximum number of iterations should be greater than 0.' h_img, w_img, _ = img_shape num_target_bboxes = np.random.randint(*num_bboxes) @@ -206,7 +261,7 @@ def get_random_bboxes(self, img_shape: Tuple, num_bboxes: Tuple[int, int], height = scale * h_img width = height * aspect_ratio if width > w_img or height > h_img: - continue # Invalid bbox (out of the image) + continue # Invalid bbox (out of the image) xmin = np.random.uniform(0, w_img - width) ymin = np.random.uniform(0, h_img - height) @@ -216,14 +271,17 @@ def get_random_bboxes(self, img_shape: Tuple, num_bboxes: Tuple[int, int], rand_bbox = np.array([[xmin, ymin, xmax, ymax]]) ious = bbox_overlaps_np(rand_bbox, rand_bboxes) if np.sum(ious) > eps: - continue # Invalid bbox (overlapping with existing bboxes) + continue # Invalid bbox (overlapping with existing bboxes) rand_bboxes = np.concatenate([rand_bboxes, rand_bbox], axis=0) return rand_bboxes - def get_masks(self, bboxes: np.ndarray, img_shape: tuple, use_blur: bool = False) -> List[np.ndarray]: - """ Get the masks of the bounding boxes. """ + def get_masks(self, + bboxes: np.ndarray, + img_shape: tuple, + use_blur: bool = False) -> List[np.ndarray]: + """Get the masks of the bounding boxes.""" mask_list = [] for bbox in bboxes: if len(bbox.shape) == 2 and bbox.shape[1] == 4: @@ -234,20 +292,29 @@ def get_masks(self, bboxes: np.ndarray, img_shape: tuple, use_blur: bool = False sigma_x = (bbox[2] - bbox[0]) * self.sigma_ratio / 3 * 2 sigma_y = (bbox[3] - bbox[1]) * self.sigma_ratio / 3 * 2 if not (sigma_x <= 0 or sigma_y <= 0): - mask = cv2.GaussianBlur(mask, (0, 0), sigmaX=sigma_x.item(), sigmaY=sigma_y.item()) - mask = cv2.resize(mask, (img_shape[1], img_shape[0]), interpolation=cv2.INTER_LINEAR) + mask = cv2.GaussianBlur( + mask, (0, 0), + sigmaX=sigma_x.item(), + sigmaY=sigma_y.item()) + mask = cv2.resize( + mask, (img_shape[1], img_shape[0]), + interpolation=cv2.INTER_LINEAR) mask_list.append(mask) return mask_list - def aug(self, img: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]): + def aug(self, img: np.ndarray, gt_bboxes: np.ndarray, + gt_masks: List[np.ndarray]): op = np.random.choice(self.transforms) - op_kwargs = {'img': img, 'bboxes': gt_bboxes, 'masks': gt_masks, 'img_shape': img.shape} + op_kwargs = { + 'img': img, + 'bboxes': gt_bboxes, + 'masks': gt_masks, + 'img_shape': img.shape + } img_aug = op(op_kwargs)['img'] return img_aug - def object_aware_mixing(self, - img_orig: np.ndarray, - img_mix: np.ndarray, + def object_aware_mixing(self, img_orig: np.ndarray, img_mix: np.ndarray, gt_bboxes: np.ndarray, gt_masks: List[np.ndarray]) -> np.ndarray: gt_scores = self.get_saliency_scores(img_orig, gt_bboxes) @@ -255,11 +322,13 @@ def object_aware_mixing(self, target_indices = gt_scores < self.score_thresh target_bboxes = gt_bboxes[target_indices] target_masks = [gt_masks[i] for i in np.where(target_indices)[0]] - target_m = np.random.uniform(0.0, 0.5, len(target_bboxes)).astype(np.float32) + target_m = np.random.uniform(0.0, 0.5, + len(target_bboxes)).astype(np.float32) rand_bboxes = self.get_random_bboxes(img_orig.shape, num_bboxes=(3, 5)) rand_masks = self.get_masks(rand_bboxes, img_orig.shape, use_blur=True) - rand_m = np.random.uniform(0.0, 1.0, len(rand_bboxes)).astype(np.float32) + rand_m = np.random.uniform(0.0, 1.0, + len(rand_bboxes)).astype(np.float32) target_bboxes = np.vstack((target_bboxes, rand_bboxes)) target_masks.extend(rand_masks) @@ -297,89 +366,89 @@ def get_saliency_scores(img: np.ndarray, bboxes: np.ndarray) -> np.ndarray: saliency = cv2.saliency.StaticSaliencySpectralResidual_create() (success, saliency_map) = saliency.computeSaliency(bbox_img) - score = np.mean((saliency_map * 255).astype("uint8")) + score = np.mean((saliency_map * 255).astype('uint8')) saliency_scores.append(score) return np.asarray(saliency_scores, dtype=np.uint8) def __repr__(self) -> str: repr_str = self.__class__.__name__ - repr_str += f'(version={self.version}, aug_prob_coeff={self.aug_prob_coeff}, ' \ - f'mixture_width={self.mixture_width}, mixture_depth={self.mixture_depth}, ' \ - f'box_scale={self.box_scale}, box_ratio={self.box_ratio}, ' \ - f'sigma_ratio={self.sigma_ratio}, score_thresh={self.score_thresh})' + repr_str += f'(version={self.version}, ' \ + f'aug_prob_coeff={self.aug_prob_coeff}, ' \ + f'mixture_width={self.mixture_width}, ' \ + f'mixture_depth={self.mixture_depth}, ' \ + f'box_scale={self.box_scale}, ' \ + f'box_ratio={self.box_ratio}, ' \ + f'sigma_ratio={self.sigma_ratio}, ' \ + f'score_thresh={self.score_thresh})' return repr_str @staticmethod - def _load_example_data() -> Tuple[np.ndarray, np.ndarray]: - img = cv2.imread("../demo/demo.jpg") - gt_bboxes = np.array( - [[609.2460327148438, 111.9759292602539, 635.9223022460938, 137.42437744140625], - [480.33782958984375, 110.44952392578125, 521.1831665039062, 129.6164093017578], - [295.34356689453125, 116.82196807861328, 379.7244873046875, 149.78955078125], - [219.83975219726562, 177.6780548095703, 455.0238952636719, 382.3981628417969], - [0.2255704551935196, 111.30818176269531, 62.27484130859375, 145.1354522705078], - [191.24462890625, 108.73335266113281, 297.60186767578125, 155.57919311523438], - [431.61810302734375, 105.31916046142578, 482.30120849609375, 132.21238708496094], - [589.4951171875, 111.1348648071289, 616.8546752929688, 126.33065032958984], - [167.97503662109375, 106.92251586914062, 211.0978546142578, 140.34495544433594], - [270.0672912597656, 104.84465789794922, 326.28662109375, 128.14691162109375], - [395.87152099609375, 111.33557891845703, 433.2410583496094, 132.824462890625], - [60.50261306762695, 94.38719940185547, 85.39842224121094, 105.71919250488281], - [373.8731384277344, 136.39341735839844, 434.0091857910156, 187.2471466064453], - [141.07984924316406, 96.27764129638672, 166.647705078125, 105.06587982177734], - [224.4158477783203, 97.63524627685547, 249.95281982421875, 107.63406372070312], - [556.1692504882812, 110.58447265625, 588.9140014648438, 127.4863052368164], - [77.04759979248047, 90.36402130126953, 97.74053192138672, 98.54667663574219]] - ) + def _load_random_data() -> Tuple[np.ndarray, np.ndarray]: + img = np.random.randint(0, 256, (427, 640, 3)).astype(np.uint8) + gt_bboxes = np.random.randn(3, 4).astype(np.float32) return img, gt_bboxes def _test_transformations(self) -> None: - img, gt_bboxes = self._load_example_data() + img, gt_bboxes = self._load_random_data() gt_masks = self.get_masks(gt_bboxes, img.shape, use_blur=True) - transform_type_list = ['AutoContrast', 'Brightness', 'Color', 'Contrast', 'Equalize', 'Invert4Mix', - 'Posterize', 'Sharpness', 'Solarize', 'SolarizeAdd', - 'BgShearX', 'BgShearY', 'BgRotate', 'BgTranslateX', 'BgTranslateY', - 'BBoxShearX', 'BBoxShearY', 'BBoxRotate', 'BBoxTranslateX', 'BBoxTranslateY'] + transform_type_list = [ + 'AutoContrast', 'Brightness', 'Color', 'Contrast', 'Equalize', + 'Invert4Mix', 'Posterize', 'Sharpness', 'Solarize', 'SolarizeAdd', + 'BgShearX', 'BgShearY', 'BgRotate', 'BgTranslateX', 'BgTranslateY', + 'BBoxShearX', 'BBoxShearY', 'BBoxRotate', 'BBoxTranslateX', + 'BBoxTranslateY' + ] for type in transform_type_list: op = Compose(dict(type=type)) - op_kwargs = {'img': img, 'bboxes': gt_bboxes, 'masks': gt_masks, 'img_shape': img.shape} + op_kwargs = { + 'img': img, + 'bboxes': gt_bboxes, + 'masks': gt_masks, + 'img_shape': img.shape + } _ = op(op_kwargs)['img'] return def _test_multilevel_transformations(self) -> None: - img_orig, gt_bboxes = self._load_example_data() + img_orig, gt_bboxes = self._load_random_data() gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) for idx in range(10): - ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * + self.mixture_width)) img_mix = np.zeros_like(img_orig, dtype=np.float32) for i in range(self.mixture_width): - """ Multi-level transformation """ - depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + """Multi-level transformation.""" + depth = self.mixture_depth \ + if self.mixture_depth > 0 else np.random.randint(1, 4) img_aug = img_orig.copy() for _ in range(depth): - img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_aug = self.multilivel_transform( + img_aug, gt_bboxes, gt_masks) img_mix += ws[i] * img_aug return def _test_objectaware_mixing(self) -> None: - img_orig, gt_bboxes = self._load_example_data() + img_orig, gt_bboxes = self._load_random_data() gt_masks = self.get_masks(gt_bboxes, img_orig.shape, use_blur=True) - ws = np.float32(np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) + ws = np.float32( + np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width)) img_mix = np.zeros_like(img_orig, dtype=np.float32) for i in range(self.mixture_width): - """ Multi-level transformation """ - depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4) + """Multi-level transformation.""" + depth = self.mixture_depth if \ + self.mixture_depth > 0 else np.random.randint(1, 4) img_aug = img_orig.copy() for _ in range(depth): - img_aug = self.multilivel_transform(img_aug, gt_bboxes, gt_masks) + img_aug = self.multilivel_transform(img_aug, gt_bboxes, + gt_masks) img_mix += ws[i] * img_aug - """ Object-aware mixing """ - img_oamix = self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) + self.object_aware_mixing(img_orig, img_mix, gt_bboxes, gt_masks) return diff --git a/resources/oamix_examples.gif b/resources/oamix_examples.gif new file mode 100644 index 0000000000000000000000000000000000000000..a98919c7ee771b87b21c124283e1b156213a5f68 GIT binary patch literal 5278753 zcmWifXa1-SSII0VKx`$js1MmvQ~7dkq7+1dyA zxOhAJcsmCLdwO~L`uqAt`h?B&^$iI3508w9iVO=3R0+jN;h{_1y!{f`0rBF{IG4yI z$EZxluzB3Tc<-=y&zT8cVJTkGDPA!t0g;IT(F>$AmEkd&VF7c)0v9+XE_R+f-$7Mm zAD0>!pX?jIAW*s3Z^07(xU|s3BzgRT(3H%u)FojlnbE0BqT?6K7v#odW=CeGM5N^V zEn4B9x;!#-c}&&{c}`i#^5U6=8>3clkgY6mkBm+Xib!@2Qe{Ub%}kgZHFr_CYF_xf z!o(zH^5TR=If-*q79}rSyeK;>Gc|QZwkmr?Vot&0j8(}6MarU$%CarW6-A4SH!j+^ zdGY$y3z8SFShQqCTIRZ(g2J4_4H>I7D~bz?*B5Qvv~tC&jjPv}6>r$Ge(Tmv<(t;8 zUbQl(uxj(d%B}16En3!^ATO_)SyD4&^-lZEmCDUki%YksZ>cIMuP)2quC8i~YU=RQ z*Crd9lPYVoYMZl~_U1G-rnhutH+H6V9?Pk&U9r7+-PpW+P3PXC z&VA)4noHU`R`2axv9Dv}&hAy6T}9oei#v{$9qB4RbF}>U>58L08_%5G()0DI!@ILC z_iVXwdhxfHH&|}1`~HUC!gY1Un<_VKtz2DFpS5~lXja>1cAx04KR@vGl`Az@ukE>ibN98u!`CnO4iBAwaQE1YACI3rf35MGTW7D{K5^^$ z?K`*LzI*@pnf2QFLHEB=GzI|R{y)yVgvU`~SRN^#*q zE6#Ed_=^nV+RDl}kMfp0ML)wa)GOaexDW8WlSATrXAHy{9YaIJj?(1@y}P1{=@=5P zPGE(Uu4w5SH|ZVia+(G~EPY-hKP+)@D9l&U7FE>lvW)bxIT~D$B`a3UOKpn|%~h*G z9N;08e^#1Yjq5wy8Yk1Xn8qu=j|{08X*pOz2TT{K$_fuQ z&^h*)%3?wx2>nSQ16ELlRdPbr8Vc0eLwj+;2y0=6b66Old}vIbEr?-42;vf9D~eHG zx>%+Gs0}K^)#tQqXhIUJ=}*uIf^sH}l3@>&X&0mh7OSa`jttRyX8~Nn;PSGIJY@cnQmC|C;Ky@q|_e$?;@XI1%SuS1JOAxvMhV1>S3@w1X*#r%Io8sTKw~3C0GaydEsNO_3P75FUgz7&B}d^&gA_dN^d0W@ zp!3z)MW45pj(F5IzI_GQ{&XtVi5W(Cu%s>~EoaZy(kCl9tC=~q;Wy0$FHZHLkrt~D z(8oF4BnIH>1Q{M{7gMI7$*n5YMIX_NdpIXNeyxJLbUcocG+M&)il~qcFpg}t{y$rI z*gU|Cv}D@F&i*b;g{^-hi?^r8ZNL)$T7>-yI}WOJT=K3_yi^rcFsyh|T5<(@$0i3- zt=Hp&FZ{{oUrx#)Eo7-{9<~_r*Q+U%A6E6V$?x*2K1MocbZsn++q9T&|0Gg{3UkeC z=|WVZ9&#MPNT{&rFopU)86r-7;cq5jhDeQ1=85sCb&>r25eSfyQxyI88J0deM`NK@ zL-VXSWCP~GtTg65pOw5Ae1BI5WQl?#VuJe8S0*MrybJ_@*lj`v6_H_x`2*e)R9img~lrs3$3s!1fi)9QFdE^#b>`$mjyzc#;*};*e2^wtoonq zS8`6ZOW>C!@8!i47`lf*j#OQJKC#d5uZQ_D9z=^s1|$HY`A8vbSPjuraoX$&f~B#d zc1MtdYb-eMOl3hUD}E zK`oPnI-4AN-xQE+>u+%VNnAR^LfT3?nr;5(fSE)Cx16H;S*^ulD~6K0ZV6W|coQXPVuS%Qb0{YoL&MggxSN8xr;Igfhmb z9r%ZumDXhUWYK~n5qvz(*)(;Oe+wMS4KT^Q#gcxI(_^eb04Ld5ro(YG(&QT6e!SIOJtAiEhMy1bQunD zHDI2oupgD;kVtI;Jj-gf*b4JpKvkO~uQud4k#ST)D?FtwS_*WLCM)#eWZG8qaI39J z60&~P{J#gw9{rP=h<;rPlG4leWj;`5Z+_#Vb(gfZ|C$bY7#6@%qxQnnwAl7yPnyV_ z2L}>$Lyvy0YJB-)#GeiuPwZ`ne`lx}Px??}Z)`Yv@H)#rpAn7?#Z7vU3f@rh|Ol=8(@t#|F;}w8l#kbXoLZKW}pLE+>rye)VJ!?|mseF=*N&P(qTQ4$@gqaQY`y zF}W-&1M@t#YDR?Ym%X-H^053$ThKZEon^7RTE|;`GEPdFv0(1@ZIb=+2?by92kC@! z(fU!+xKgHfLlw;Esbe~`C_F2|%M*M1Dv^2>^m(#%y1kf{n6G)|>@~{tlR&6^{Ht^6 zn;?;j%zd$u!1Ms_52fghD}Yzs1Ddy5Q7*9_c&tj?dflU&qoQDbOBnFz`l@uA26ozU zJLu-aqL)r)SWt*_C#bpdc0P!V`5l!|1pWHss<1M^3@D^}A~g&Ljd}B*2Z?PymW#K- zaa1&d7gU&ZyfCPwiIcW9*#cjFzyv|&%g8;+#RpB+Qw?NJiOl?vA33$RFRVazU!!yW>CFYgpspNdrJzatU3 zgjaAEpYuf;5YD)O#^>X%4iYWTq<+kjkbk{8yO(EO2-O`5rp1pFvKEAXIZ4L>4Cuzk zZ_fR7;d71HvTVT_A343Zkr}azprOA-e0$8u^mr027?!5ToKMpaGJO~Z0KQriUS*zWwx33dKwrh;l&SU^yM~3Z%HVk(|uY9GZ&^ z`geB6Qaf`FyRbs$f<1!&bc$UbGWm>jZJX(C&xF69c%Q7Va2Kd8N*+#3=XmA$B~Bbk zm+o&cJX6!R?8fbbOKL>2guSBWGv@X2Q{0h6ix?RRa~_4fpNk@JG=!qp@@U5F?hDwZ)qXQVldOA;EyvoLgyy6AOTQM4sEv?TQXGi=R9+5HZ2 zeM9oO4_{2{E9N6J;?>Av4b6Fc8eQeP?h1(Z#cv;qJr2#bK8G(LTTBq-QwI>i~n#*w%xsHo(XC=5iD16zmRy;X&5M2;A#pCqTn$S?4Aoo z8<^5WQ~TTaiJR^?fGIgn+9{Uz#r_Y%vQSp)Cs94kHfice4HtTm_Lt!7`y zSmHt0eyL^#%787X-S)!M4u;-5m$0imc$CjHIHqMrAif87P{UFQ>ZgVSfMGU3(#-fo z6)dG7EmAuq^!)EP{^55Z;G?` zqF`GSbOl0d03?$+bLkM@Byky@^T!$b8(KDdkISV)fnS$8^B-o~;BYVj_^;eko29$R z?JJt}J$xvka@HY%(?0i^;acxah|$Be+KXCPtB(G3(z*R0V5%mAgJXrcfOFfDx?6_s z0V$<2#vYs>DucXLbdd^n%7D&=g!7u=SQ!##rOisWv&)0(`od?K3K%oCrJc=(-LXVT z{xovBK*rDl)<7mQ)Fhl^o$q`%$N>~0C&ecc)9jXG)g%!yrPogDkb{d?`e0FtqB$7D zkA&hhv>-?*;|oVmA<AXJ^}5Ni)KIC9pp zWS4znT>8T^PMFFkK@Hya>Ca650!MlH_rEJYek#fr19^a_AE!Tk1>e?W+@NQZTqPtRzVk^A}XO?2%wuSC5n9T%@YdamCRlr5I2%~(w%rHv=pIQLg+@LG*qyXQ0p2DwV2pi`*31`!jm5X+2B zhB)R`Vi*cmp|6!2EM0+XuRBH2rUh$2FktYE^iyO#+bXST+;1LmtPwu-8PLsBuzP>tU+-d^JRH}(b78CC2tnJQC3C+j@PssQ zmfdI440pN&cH0f-s`-01vO-MUCcq@r&>RW0N5$JJu~}oqY!*VRLZAm?NLLEs81%L= zHA&5|EBPW!7cjG`NnD8?EfapeOXT1U%~COBX2d=Ldp^DkTL48va0-AWgy1uJN{A8u zagQ(|eFmdMW^4UAH4llVPB)6nuJY%je*Ah-*;YGgDS@;OWR z{#WtJ-9oR1Rex0r9V|ehhJ{TI??W&|!zdvN>iZ^%#6_6R2FTkIwqp^p?S0zMB7SYi z>=m096-|CA?jyiAd`-r6My1Jn z+Oz1bM;)7=fdG}wq`5l4I`foA16yKKp#iQY*x3jWj5nyMTWzEVjq{#M)WAw7HO|Ks zCF=Fyb_?uL+wr{*Sk@l6VOh@FZ=LqO-0We*157ZBnC+{Afrcry(g(*GLniJ&xa0jq z@2!VsuP)81-q7Gq$4}jx#xf<|J_r1a^kNAEzXlIC13HS@(NNdD+c{Q-huLG~B|Sqs zhk>`*Sm;yEOdSdZ+;gm+dkBtSQQ#fps} zR>hiQlD%!rKk5}hwi%mI{am&p&Ws=g@^n}bsfHHw(UqFkp-e=c zo;>L|1DXcrXO~Qmhlw7UO9qDMH5HmRKMZnG&EdCe_Cf03 z6^>8RVT+Y?r|kp>MK)`!bs^ z3EMPHcAnLClgR@$=eem+S1U-9z)y;3|CBQ}nP9%zk#V)c$pd&{S|a00KC7T1z9_c<5_{X4rceOEX!)EW7?&bvnd)BKT z-=EV(!&@4Z%6gB}KSq6@;45l(l!ri}nh?2~^;pf?gRyp`v5Ycgj@mBNjM(Jymw8C| z<{~!vg@eVbtgc}0MZ2H+u`lnY#kMtnIu$%;R9M*vAx21n5&uM93Bl2rKEb_AbP+)2 zD1`M5;@Z40w}J3kdHh)BHAls~z>|pcy8@c$H^NC7dQZn+Py}UDw3QmxdX-(pZb4Da zwX_=HjsJwvX67wDlXkSD%So{qKynUFmM-Cz~biQ-r(%4$m{t<3J`sdk7g6@LRq{m(By|zf97p`{t#w-Fa~4vIth3B%gFUkPD^jQ z?i_G@lGKnHQ zSuB)AW9$v5K40IdzuSD{JG*}xu@f&GjwE8K3Tapg()*oV6qIg#v8w4|P_{iBL-FbO zj8L`NDg7II;Uv29tT3=o5Im5^un29i@1qG_oOQkX*y*UZWP(eXb6x5BfBAu5C1(u1o0%`VK)Xtx zG}-=Zv3+TVv||Zj8T*4U`ex^)FlO@%@Sfq!8{>={Cfk0(cAeF8t=0CiEGn%h&fAyq z7-PsxaEXL*S4C$6#3d^7?%7<46=s~GOYxeSRJ2p$(zM^~nJTy#uMt`QJu`_wYh~;$ zvR67MmzVc_ekhHz9Q4^$5}o}5Ea=IL^W9;exdb=T>rTVQwOPn`W%~k1N-?&^LbCxJ zXf}Ou52}guhbW&w3I=_E(DFp{OvL}WZ#tKYGx=Hqrla;!jmCpQAS|mbYxEg}0$VIb zV>KysIyM|8k*f2-tQoMheB3&!t6^h-6`1-U!uMJfCho$!GZ@KkZNP}@lk3<{y#!?? z7aYxGF*OnuNuX&rV2oHq{jkqof0O0mHE47X2}IKeDmj75+^^BQ4l8l2;GdsHjU|d$ zJy$%lVo(C<^ddj6++N~&s?p6sn>lIL^u-o(8X8ek}X=$!Vm1FX@LS=)3Ug{V*c*D?E z@o7CL6dyJm&$?>7XAf>cWr>51F$%AILkbg0@%$ZeUMu^0mg5T5+L+z!33ZHPZ5Ii# zGt4IRzW;AP=Uj}9>ShD@x+!|QYe)IEuCvo z&#eQ6v6Z4si3x%aq#8P#eM+rEq^cOuX{YWbLvXnVuJhe|UFqIZ5o_(R`ZHt?ti z{rAr5m`Y|OnSSkPDamVE#!3_kBmid0Wu~O*t~d!F3oDe={~rfi#{WqUNTJHfZh0G) z)!o5RW_h8gL?h9QMOC1)%|0i&vu;5KjzZh~hmN$)m|W57Cs1{E5Q*v zJgbv_e5u;~h(lN3iKyU)-`M|XdN1#H3q1LvF|dwOv_;PxT=rYkugMy^t(2Kp^*ZsG zr1Iv+eVky+Y$LDQ5k0czJOY$yI*2V+S2^mV|uPIH6|Um*fgr=iCYvi&>Oo{vrxbJI}qD^i^FNhx&ejF zxmU_J+=35RD`f6oRGz`HiWysD*||O@gjFCjeR^o*9IVo1F#CX$R~Mn%xyk^FMe|hO z>J76da2;%KArK!5qQz70bO6vvpAeF@ngZOyJ#oxsz;}jPy@(_@Gfjqgw0~KGmZA$@ zb)Han_%2}smyU*$ZDBtKTK}Ys4D%>=u zrJ~4>CQ%~(R^n7g0EA4%9~nmAV3L49Wgjdsf~QfYW+~nW`J1%UOjV0{EjVZ1q*`ke zCPTB{5N8j20}*CU@@siOw^soqv)QP|kY7-3l0Y(o+|DcQ`+er{Klokl#IxEH|ytCX|eGcJDLKh|6Df6HvNn1h8 zUpuxHX=IFWl}GS*&h;{``0}Ceh7aO28B@WC8K+oM zg0!*Xj0u@Pl1Fzpu zX!T(0(TrFi@&b8M1jJk6wxY)G($0jwYd*;_EMO> zAP>MJ)~b;40oqT4)h@RsHeZ1DbULPR|74-lPz%jT)yJ+|sEt#`iHoKp0>NEi^W>Ya zq_QiL9KtXu1mChM)F;`w8UwD9WJ<+rk$$5#v?-5=7ZRu9@ndUG>7lhs9RgdwiL~1K z{BdPh!rt}^>|f0^+*~W2o9U>Nn7FZBZ@Z)xV){`l>fPeR(JyfnswhUF8fRph^skv} z;HC_Q1R4N$;QBOo5NwMDOrO;aIXn_1I;rT*=gx9v8U-^e5jdrLT?%;g`_xJ!K>l9< zvm#NwQ1zC634gkbn(xynYztE#T=M98>^wnp^NZBBV5JaL9To{(UZ6J8v&vyR;u{w~ z|5O$@D^LN&N_XZzL;c&fsnLedw1VpzMquhp)KmY~W&s^oXFHJW$%`xKEHhOGs?V`o z)Hc^KEv*VL_k6s~h@cKU-#wFHntln{ve4$QE!BIuZzQl{g0P(}Bfef_vWdpvX>~Ha z67Od=)NIPz<$2)Gm}|jHzay3;cph+)48lHRzp=PD$x`D;yj->U*O1x2)a6EoD{p`N zb^V^zUqOOwF;6F34*_C@Lg3VIyKQ{Wp!ZzG^M8rw`UN^0sfjtmNHFT0Pw-p)`M1zl z?n}%d_u=&-RNV&os|jYIm84CyLlGvZlbTW-R!_z6UF2^976O>i6-h!^9aA#2L;IZC z{g6D?M){{G0H^KlofDGd#X|n|z@ln`ZQ^xue)}@{5_!jX!!kmig8dvVRzG ztJKW$b9Zt}pRD=!QLsotwwup{{@wUZ*7k3*3^Xx+$5GO~MR&q)UkV-<+>54Yst?^` zMAYe(R_*iBb?g}ORB4T-2$8~k*wYwkIaM@MM{+25+;1D5Bp(;%>@(3(M&pF{`Xt_?kim#2hZPo ze2zJ?O9-InNf1U5@JKUjqMS(GoJzFNB|QOje*_g*zRNvqGX{h2T2tNGRErPvG^kdJaLRs zt}D>!B2*?up*C!U@UakfI8=xozR9TyQIiWA*?(;hSIba*e4uuWkV{dK5s63!+s1V| z9AGL)-v;Pe&FgbXArAg#>DWvn%sS{l3U=dr5sFzPhxkln4yLoi_5=>Je72m$Abk9X z&YV!i7MC0psi9R#!8TGLD%1)`cKu(XWda_J-q|IjY%17MLD9#!+!!)D8B9}Oqftaa z*OjKO>s4}`Z3QpbcAWDaG$lECKAQA($A4sVP=XW?hDZWLYLL_m%$&HIcNq@Nh7O}j``?7^>q4$A$@{h+*s64aU3!ri#1MLCnN~t*IbvJS zK|LgidR(~kQTYy5VST^iI2VJ%i33i9q)iur_5#n5EA?@A-=*8IRQoSF9awhJlG0L{ z+%l=VDc6NHoj=#kdz-#-aTqSQKEf-cR5_@2Z^4#dR{vy>-BWA!{7losmpUtU<2aw-Bl+jp6%8|fcSj;+mt zve-mt6ag^zDzzu4)lk;WE@Owb%1Jf`Y)r(bRN)V$owFtt#EV}S3tEYA=Aj2MB#WZF zPXe|i&c^6K!)K`8?|C6GaY*rJ$UvgdFgquWo=faVbpo<}Y|xh>K~fLx&;FyaswI}z zZjWO0M4ZSY2&7lJNo&9~P=~b!yUzAq9dcvI-ew6YTgCEK)Kp$K9A-9jbYBcT*YV9@ zSGHiNluuihd->a_?gE+L7cFsWP4$Q4^LKc{5(83x<4SBT0jac(YEn}J9wn2uEd*Px z3uW@oW5z%OabCWIBk4`8vs@ViHcZ}ARbc;l($Gs@l>{a&Mb?NJ?qeXvte=S=J2cbA zlUW%cJ3*arc|7LoNOcc1g2_wCI|5X2%X`903E4>C(opR?eE&P(#%%_do=&?fu8+d1 zLwnKh&T;!%wacW+?|Utp-00Ds#}1QGD|W}`Y@4-)EZ$S>7x{tVknS74il67 z5K$6YLU4NwQ{Kxp&(Pi7SUc!mr?%Euf?Ece1xf52cPYsS|ApZSWzGNI337_GcID4w zmqT_kUZhr36hzeB1B)BkHh>KCBXjmGW2;D7SHN@y>2t8QYVi9e8E4%o*f16JYRE{# znX!tN6InIeyAOJmAxr$l+DF0Max(W&L&L<^JnM^tS8D12>pw;8*iA{yIdPkD-rs!>o}E4Ogl66j=P!6&d5Ss2{!Yd2n)`B z_DX3jZdx)Ih*s+AD*|oHa}IbK7PK@c355a+WZSPX*uR>hJsoU8}qmF79^#7 zAEy@zmSc~xb472WECu_2A&ef_j;dht;{@bg_`O^9vHc5ndIz<=klK);JC2reI3rHA z$J#g}(0xh$BdP5di93EX`#U9MD02@S|M9?>ocve=Vi%GuQ^e!G=xL*T8{&}2%>2b8 z?m2r`?TC=AhJ1elHZA<1o1q8KUVKk_qEnZVEaiksiZDHCo` zg0{AStHo??78HpM4PHJt>yi+@-eo@qM4F4Mj~cd<68@HkvW8CDpd_H{)V?hHX;EiS ziS3$4h{iFNo8Q^D?X@i|^@1t_WrCt2!e?ZEjWV`FRlCgb2RmgANwl0xt@-zTKS>Za zV@HlG<-`k<;1)3Tv5@)pW8Iw;qH+^f?M~=-oF>@4ca**0Mj$D9`;Kis*8k^^nqcVa z=A*9G$yb9pi5HxRcP(unOgt`8_dm_CFgb10*zvTU_gBvRqTNMy?*i^!`Us$i0|)kuBV;6oKoiUbk>d&DTG!GzPk_|a`XA(rvx(6Fui|-#FFARuCvjQ zo;&z74YEK9ebjzl`**NA*?^sju2uRnvvg)92NXi3C8S!Z2h^KRX=4w1CViAY{oK0g z@4jae{01*J4MXC}buBlokeu1xea=P=RamOOuvTjdb@5ppUE?hgSAdlc2C(E7(IAvtY@i?BpFU$wWy@~2B!HAd;zERB)kPLsxBa65wgjUzxS%$|0RuTT z0`+utnhq#QPnZ7N2VVOe!=S&GfAg{;_Tpt0V!5?F$;;4ScK|aWN=_X#G}pazDu2@G z^nbnBE9rgL;6wm(iPGGCBpE(4 zXsS!hsF0}dyl?Snv3b}Fd6es;l%%9kcWUU{C+53DCtd!``e`J4r#5VDf+O5*kx7^v?B1ou=~d^Kaei>lW0NFW7%)Kkpqe_8xQj2kUiwaQ+!ZM$WzJ z)?(`_)l;=E+n+Zq~ zrrYM!{Q*2sJ&(?)VoL0l1gLQhU&aFWg@Hld>ss;~ck#WQ9DPT~rf&B9Fby97znXyVl)7E+7A_piEZ z^5!DdNMKw(Wf>i)@$sG4v~ZMC=)jtZHB4Y4nfZg(CP(v!JS`o}J1)U`;#>kG@4oI;<_d?Unfl-8Iu znoT?=4h{tfmn-P%yGv49{Y1e(K5A7r?Db^$vh)33K92C@;}27L(s2Oon(8KUP!*1I zE?q$(77w=oh=GGX$X3(SxGpw|DFBeZY)VoZOq43{I1aCvw3Sb&Gh+Nz9w@ul+OPZm zL3~@y+o2|^V*Q8p^Sv9m{92k%mdxABCs#w#?3wQIFe}qc0Zv^M1swEA!!$dU3Tl%w z!D5=X$rwwszo~)t7OrhwfmVDb>7$7|Bv5;d)fj8PM37vF@&}Cy`mPli)Hx>?LOG?G zq1Sfnu1k=1z0DSd%^am9ZqHJW+u6_H;t7veg*Ewbj(!!%5WRTw2Z#bYIL;zjZr#tNZ67-#%$?~0hc{i zSsWttQV|f-PNslpE-E8n^DTXoYq(xxpCi)VvnoS{YkWHEc|krv$NDr{MA~&I5lwhP z9Q&4AM*Nt|ts8+VnfiuAPH_*QoDD!m!kZrS?E7-h%Psw2GQ{dem!^8!xxY$11cP8h z%q?ebn}vjn40(y}-ep(#Pzc{hfyp4%AF;t^qjsS;`mEaR`RYSeh?DG~*63$aA@t;_ zrYD>CNikT?W=dk+WZs&erv1}F=A^D)_OZ%II)PQPbx+=*j7?FzK2|Wc30C>c2KQOu zX49YflApSlWaLfTqx$^3|BG0t^k>?#&so9N+=rK*o&jvn>9`{OzZD=>IQ?H9poM>G z`x8F*zulAJ>>$0(K}0uj7j|D;@oL}ZnIWb=mSVPe7aeU8H*VP&uDPT=zi&Q8=NV@}>}sy%gHWn=p?WV=ifTG5vQ^1gTZ_YVCo$ICpT6tgDyJCvC?(Qlr@3V?tSoeJOX!lp@`d{for%3Y(+0}sh?b#JB{al&xa7;KKReTy{U}7Xy+#hG`?)O!3N6f*sLV70)DSpGM-qRP4Yj z`z4rdeG20rzM`$hX%AK(_)=u0OPJ5G02wsr{3smRM(kf*epRKCVd0VifkBB**=E3~ zNCU72(0nXlqZ7Kd58y1Rxw=q)$oaNZ$CjIp!5RrI09y(oEHHmHo>-|tU?-qOyiCnD z8z^SZIlyQ#LF=EnvNJK7W9R6kO<7?--T&%+`aL}qgcC~*e>ndfsF(9rW1BF{{IREC4=CJrl!#u*{Mw9PDQZj9N*DBsKWOV}M~ks6h?o z3U}>GRgq{D1+UGRkM!>TitjWdQybgF_a29?X8LDsEfoo-7j%0^yn}Zh@3T+MtBGnm zSUtJJ-;N$5t8`;ROQ&T$IFuQ#b2T3A+phl^ZYe}uy;AYfAOdOWy^=iymQMR@_o+Yv zDRobQi5KIYJ_MOjKxHq_Z$`IeA+ye2zr5Myq0I^v5}lOC*ix?W^)-oaN{P>f)D6_< zu-5+KNn~sGAgX!o2H(s;G%*49{YCenYUZr0PSS}i_Ye>_%<0T8Bdz!@f zsuG=3JcT%QJ!<1V1okS)37e8t2wOAlyD2mzwYYEs17JZHPD_(O+{lDmQ?#o>QQp=aPg&&`#!eX!>)JF{yh0BkaqoYFUgNs z_v(jxHK$DOITw1K`KYv_4uV!gvZ7pbv|d;M%vQ(R`izqZ7XZ~=TfM+z*R4k{dsx`~ z)SsR^qhVz9z-pj=_E)I!Qz~t9sGHt+`MW^Ikm8rW+rG>dN%$SUUavSe!+jCKp5;#z zUfAKkGY`$C-R@3!uR90Ljy!TaO?`phHRawV)ftN=@a;t6J5=HmRF9>xu*_DMnX#y` zW|en5*lRf!_lHew244ma-jnJS2RZ2tjgt=qtzyVqlgIj~)=8Wye+sx%OovUtX3MxD zW9Vp~>utchUI59vMuCHw-0Jhz3YeA&5r@e?T*D z5sAOmMyow-7_nAZ@aZbeo!M9L(8NT0tpAF!!|VgrD)h-H&(dX#)eeC|6G1C5f)}19 zvEmy^96Gk*10RR<^KZSm0-0_8prsW&oy)WU)~5_c58<4Tfh92+Pkz$E5~TcgbewW{47#CSbrnUZQYeCjEA4o z@=2b$4?3=A;zVl)WrHhF%T*vJ1=4GuO+0vCv51xYL)?C0d}s4udAw5-6y>Kxp0rlW zM4Nl4#8me?QB9^<%HK7+pS`KBA*yucgG+#_xbKQ?jESPvTRs@J1P-r!kgENZSUmc0 zWyS4Nx02RC*N*O|Ha^1ljisZ8q1>1!4@%O|;h~n?#knHU zjDl%3!%)6SklSZF4mPz^9eu}&E5gmWp08eyaE)1~>qfA;meZ2!XVK)%B`J5k8xM4V zU%&gwF6JKB(!jM6tkpOxE1Po}atYi%rPa_voG~W)dOQIw17C~rO@lw*d2!{DpOVu> zZ5+O-J3c>aHDC>rbTto^1A4D8Rf;o35KljaL5Q!Tbk03fOKVZ47z}#>=OzW3+9!Q? zgEq`f|CGqN7_T0!q*hY_qs{aR{HKcegdPymgA0czty_9TFuoKqM`ATHN- zQ5Um`iw~S>9R%`$vNi0;j-t@qt1gAfwt+Zrgb;k%5&vHu$3SeKxbep)yMKmnZXTy2 z@jug<4uFRiywu6UHP1UPj|eMqvv@lL#D-e@3s*PByoh2_BKF5l*I{{0f110a`tiXp;jE)?pOh^les>M49t ztT5|ivmQB>`GNPszn$v`E#|)fdkd|>% z(FIL%#@MhFrNyAJpT2UQm>T!!j~t1wX!obj-23%df$2frMk&STYF*ISse&7`2b2_3 z4Sw-cha={7b20x)ZV&EMaTVy}IfW=&4?$fx-$lU&I2(f)@sP->ki57~FQ*#pz>eQD z;TE#R#RRE})5CCDDizm3#NB1ST20U|CM6DLu%AG7D1nHGYkXXhe~Ydnn0iIsWu+;c zspJy%pL`l(>|HhxY+T5S0ma+&LotM~N6WaVaCxF|32a5`rj$`1{buT&rlI^St&7(b zLW*^0BWSCM&Tf3>cfRl5?HnU@_q0KzXwXJ_-*7c)U9}ny6R8ztP zD*URekad^WMMhkVbMNZ&&La4hhN&&kHAs2)B3@ibEbg;=Lor_8JxYz#(EAvzE&P12 z@{Ew<0`aH(dVJGzYc-<5eIDREOT*S%+r3%jv!#uUk%f$=>$}lpG>fuby8O2SAl|?_ z-V5xy_M1m2u98d7#BH<45TMdz~Sif@(=Cr*5*oPR#> zN?0F65q4sNY1m;mCg6JYhfWP>rC$C1Fv&#I7IAm63vnm3e~H>9J!~hjwtn*g9GHRg z$mO{%IdO#)yNO)#fVQEx4IwB86Ojq`@q6%!H*9)prgrc+(spbKpWyV6^WR+^tb)RN z)SFCoF_jc9nwmf&Rb6<~Nf7L)M@H<@MshY0E8=He{<^$5G|YbReewrKoc-a-U%f_Jiqh^3`qczR7QuUi+sYgN^LVjmd|k|r?woHwbf-4*#FMha z=MQX|m4E+>`MlhF_RZ0kko7CLY%U!yTMN3ZLsriF*N8$~uZ8v+>C5v&;0fdu13R#=Nxr;<&aQ`{f z{`#6PnFMY`m~hsCX=Yxu@b$yIYxAKWc?WX;9>@cNW_Z2jcAqkZV@coK8bv>!~1>0Hvy_*oa4 zjCJ<;$#49BhR!rDsk9B_&pC%(Ku}av)FW<)OXNyccEBZEGBhhID>O4TGg2#SnwbNd zxs_;E+KNkRR%T{pJ1&*khGu0=GlOPj>)0}lHI47fCqBXtp7U_u*L_|8Tf3+vfE1M% z-d#Lpyx1Cuv;JZ|%Pn>(1mqo%~?ggj7=7e_HM(4D_NU0Y^*<~@N*|c;C z!aB9}ZVmh=W=)k;e;Sx-pubt_#5TLIobNJl7A+yIzh!=F2T6}JjDlQW23chMYq?#Q z8$sjj2hkx$l{_~AgK%Ah0#D}RMPv0Ir$b-zKE^cac8KpLEs&YS;|SX22T81F-Hn8F zWWxKiPO5nYz^xIt`XX&+2K&__#YPi0l z*|fvoay`&FU+r6YmP5|FDWrK{yj%R|!KZPdNfqKEq-_7!p*wM%g)|K#rxnbZ>4H-C6>4-V__JlKk-i!pvcRU*S&IgwqdTAuAg7a0m|3xu$s)JuNLWuK=`lAPZfG_EeBmnfP{gyRuw8g^B7lBh2&(qYq9 zT?oTOs8^_Htu~znGwI{#cEFrO^C|ItBZvhTK!pZOgb^*ghj8Hbc@`i9{&d5O(ptHh)ws?HXbpq@B; zvU%BdD7vD)H~@?T0^J_K1)KowE+}O0z^ekIzn6iV%YmWGUz}!up!l@f?dFMTZzlzR zh?7o@#D9)DZ;YY=%Vt0%PDCFDZII=&#_1bN*nLLA2K`G>_|0R+R{)_#QIcr1=FwOx zYQt@#&9loEZv(?#I)7&M0Ni(8MU2LE#+nOcZZJU@=$tUMU`GH8cbK}hV5ej!B`)6G3RNuFBJ679YJ(|dyVQrF_>}O# z>-Ta~!%YQ4M-mraBhT%>|sOUWL%NP>b)Vt(YRxsco^{u#ZSh zFA^1)a3M2sa~A_MKs;PaN+>E|$DBJ=^X~R?@+~WO+fi>5!U)YB*Dd94R+}ockv4rG zVkrCGw&izRrk$oqPkd|5{}KlWJ*c_F$NZ>waKSgvTVyeg@U|&nvo?w4FB1?Wemo&? zEj(qC?9)ftrsb2X6$ySo4d>#Bx7e0!ufwU);Acx89otn4U(lkgv%l$#dKD7L!@`;R zQj=%Tj|Mj5uoq5pB*fw{YsfS+0Yd^N3N1B-P)NM)M|AEHt1Xd$Ijp-8Tn`N_ClGcg zcaa7~aZExJ;q!ESN-l5mP^G~(3S4fdb>nDV8D0X+1l7TTMROOgt1gr=a00aqVMERM zJgPR{#~PnML3I=?-;G;{6~u!^pJ|K+es=4AhE3;A<4*%yg&G2m4D8&^QCza`kAvvk)JffMTdkoL#DxU znA)Hth-LYAaFoE~U|^7@BR1>->9f}Y3Red={=GlrgNa#%;gS11YH5@pgy$gcW29`U zqBp!q^ttfrzFduLACDqc^ciatIl_OYC54<;G0=~t^3Y_;k~WE-Q|dhHZK$KY2qUCU zYwlpW?O((y|vJixCUPmHUW`x0TR9S3}hpMTDRk+^vPw3 ze3nt(80D^2 z+Gud=v&JwW>y%N>#o>b`usc}Yk&@_Ek&KHkhE#xT1gCZYqDz{@R<2Wb0l))T?tNY_@%R1>t-EV zf4IfY0ZtHx9U3{fCNvOw`2K57NCUS0!O!brm)pvJNZ5zLVO;*6ZCLx}U{Rx>DYlpf zFA>JRW`Pm68)x5usSLrmeP#6Z!}a^LAMqZ6fYlR4`0R9BC`M@!jVP&(v%O$wuAeEzRu5R|T9v@VT6tt|e#KKcU--U0VMsX5tyB%$`jO4Bh5AQ#pKndfn z0{6xMK>q-Ou*Cf9%CZ=6?{P|RXs#QuG=SCykb1rRAc43v-6fARHj_z(&B=8eVDPXZNB4#)d8g=xKejj_S5$ynMoWcU?B85FW zMv=cLASJZP_j1wntmol*X>&tg0UpLYfe~qrnR27W3?*8&l6*P~uC-Hs&QM(g5BZJ)G-&49tR&joS+>meEzhsc^ZL9X(5n;Kow2F$jPyr z9fllA2_|8tlJKIXJ-E%CGu7(?z9K0c>{MJDvJ3+;7!wDex^?3ZM;1#vWn&a*A{Ee= zg5Ynx&Ru6(0<_!qXEbBQmzauG0+jis2@yW`w?T;%rdVI#@Gm1UbPdbRIW#pCkl10O zzs_~u4oz(8;$g%SkuS5y3Gey}<%*#2B6t^Z^M{mkp9<#8dQNznKV2M3h;N2q9q}1f z8iT_V(_jC%yP#v|bTRQA2k%nQ{98RNXI{_3@>LzZrST#2CkjeRnoGCl|Kf9f!O;7# z8TshkCf{btgxrabV5KknqT)gvzkMjCl=;3(vpfC%BtiET)|m<+ikiqUS{^EnZd1O4 znL20c#-%hAO7L;~tt`jZl$1ujZ>!kD9PmEJI+OZC)MX4Up_lcZa+^?B8>YNS6LS*kFmDzf*+L6$ecER;2d=BL#of?e!) zGUS6{_Omqq7#$-w|J=F!h>ISCXK!dTZXinywws0`TKlsCFy@fKmtwdPhf16akp4Bh zN!aeEe}@J2JeHno?5Jf6oikL{ZUc&NILskN^2nk zXRZzTOSvm??Bh^;?mr)79&Tl!`y7^K2Df^3U4HwN85Z4fa_;Qaf86tv!)RXkbsft& z9v_bpdB%^9m*={5K3h|6FbP)Xmq*_0R=Npo@SkD)!lL=}2Uzrth1UmdsT)C@z`QcNC8>hjUfYRXU9QisqT53%C@2BhkNqHTA^gf)L zj20Sbna0yGVkQ`&2^Hi*oNTFJHX7(JfvL79!S&5n?2(fzN+=lxQXOU10?OzK{bWS^ zT1>c32M1%{Th5BgU!cJTWLLUtDxZv_o*FG>a3d2WkyCY~A_YmPpd{$vcoM?bQud+* zVRF1FR^S}CgN`FDH||ajWT6Bf8Ct(oOOk7;`vKh%FsmMbWR$=!Xprd`OEKb8+=1f= zCUM%H;XuW9K8(|=VvRRJmr6O2rI5q*V`%M`7bMCwnO}Ed{_>|+|5yzkA zQnR>1%>U*uoWC1Sn-1p(aqlmjbdtyQGXGu#4WrgiaYjE14H>TSv(Jsz&%Z;4=wcg%T1PwfaM|I>%bZAO?5(v_powZ;NJ<0=*8Rs+3-Z*zJpOCRwL5wVOml2ZKcX%?l&QG) zvdk8x|-c5e+v@5v$tgeMOl{A+JY^4b!` zp?k%od*L*7o{17C#!udkPhKOheo~^P$P(5>qc9z_w!qHL5E5Hsh=ujUMkr&oZSvoI z!cvqKkAwbK3^C~qYVeB3AXb;zmsb>JJfd}^(}2yF5k>=h-K!h^mqT+GjDb)*OdVC7 zq@|{=gY69f>mhysQ=vbt0<`N;FooDCVa2g2gE$1Kn%qiBJpP!TgM39%?#c2qar3e=J&c6o)j=IyuLLQj8vbrK4QG03^RTl z5O=Ivb*#W)rEESnfM-+uJ&~O7Rnul^Gbeh9n%2Ek6Nla3K8q3ZSyi7uIt=Ys=_{hfD)#?ifXD49(a^9^kjitcfh7+AuFgWb#3tF zkcC*iY1ry{jYTE#h=VCHFcd_@U=%JjD62RJz%wGN^wpMA_!w_!*bB%xE8fLBj}-8o-nYV{WZqlz6;Uu_pMwVemJM zc~1ci;#bxZ`GM4R^_1}a@=MAlM1f@;clupikJK41KLn=2kvE-DjB1oXk}yC01V{c1 z#qa<4v^2uQ=h{N;8^3t*_^Zr&2O|rRbGx4qQ$oFY7v>I}o72%lUuzPW@a%HvoB1CU zj2nfWnWZW&cJ@~Kbj_Q9j{M*m*JmczkyEvdrCRv%BVK{T##TpE7{rqsd}q+4bCrJI zcfpsNqQ6n*zhumu?9NJ5l&olLJotE>trLIY(YXUnp%~fEW-u_=#V+i8XzESh!5nzf znA6BO_cT=Io|&6SDbo1u8uIf`Em-jV{Zo5jobYZTj%hPgFVZ!{X4TT|;p6g$zrq15 zicmpI75}^*^O7n+s*a*Ty#U;Um<9xS8sy2LRDQT~qEY~jGOVqcX_Lz5H&lM3OQSU# zRuvLi0cmuo=Hm^`o)DjVlX#q(t`Qf?xY_Blh9Emyu1=vq%qWiWJdpzHDYWGA6*_&_ z>jyRt&hdNNp5ufj^B|ExwV*V8_+9H1 z+^iuTMj@jpN}+HK=+t_m5kfgEN{Q&fY{gX$YacTw*rYx|;y2p?NwHX1=b< z82A0^CF%lzCdc4+UMN5Wo_Y*s0;SZ(*6cdChnOx6 zatU$0KrXdT2RdW6_8>l#Wyb|aHvxg+#b&BArrURoji*kP66GSO_D}zaZW^COlall8 zQBdg>u;gGm-)Iwt8NisY>2U{L)Gyn&W*`iT=s4_FFoM$%juwXs_%&(U{#P5s+d%1f z3?Nf6Fl=qHwp#u(OHY1ZmL^M2c}}!eq~`>O$yaa#v^`7T1#j{Q?j@#4yo|oxZP=plfM-C;;N>e4*cs7D5}HE- z%Z(_DAAo7To)Z#w=ASs>0iCA_p|PX265?>W9wPIKD1Jypx@MV)T?6{O#4A$^#V2)4 zhMUYI{QVokFBk?yG()x$rg<71>j_$=jv#8#to+lqdT}6d@y1;@=>BsBG%vZL+j=#j zE}Uj3`q$wn*S$I&VP4E5snT}r1_mDaYFEJSgxW$F>PWg}=h<;uMVl9&jQsibsOAO% zcGgNhIj2$TSo2b&qvHULqdCHj)j{3>CVmZV9#F<0L2t}x0a!yOv}SgPFh0Bc%ajW& zHxSGzCE)6iu{YNeGboZG?nd<}ek;s20bjp26ae0Kxj=w@nT6$Ff4u|Yq~RvU3w-q< zbOvE^r;>wuEZ#`-qi{z~JT-S+cD65Yl@QBN)B}r@cbmSa^Q%fedk_`3#Y_sI#gxOZs%K;cbkgknO#VXzaVCl5HcB#-5P)W&>uc)iO(X3NF`cu*PUgplL(bofcOYB)e|@SNPpC zGJl27E>g~ToojfUZfjjV5~(w?LA4>U{4tnS1#U7FH$Mn<)LW}9p}afkdCu5)V?Y~I zGere+sFk34jXLDXIjMje z<$I2&7|3B*2?Pv-&Y~n*PIU;)Q%Q=BLZLW-5Rw%{_y+;;>jQCoIV8zJi6mKw$)>Ox zuaCC?_8K>d$5fzXyI6hS`!53YT}A{~TIx zhc+nWi|gjq-Jy{Koo`k?PF1Zv_LE@xp*VnkH%&X_g7+`& z?%EAyPCOs$9jgYYduVc^B3 zH}V;~jpDtmg^hho=M$ejx35PkJLd80ov^Lj0+r3b);$XMd3pV71UJ);qGi2RzqF|!m_H*a*xBdEy zc?a=LL^KJ>D~)0T@M}Q$zT-2n;fJNLWi^&KWziOTxp78H31*%#P$gj>3>Yh@{C3AH zeM$JUJO2FK&*X|>o7`~tx0g3BTeun}g~E`Lp_SDO)8(NbYd1{P=C10t|F_59%|LQ) zn!UnrW>dF4#-glak*5}*Mg%PL7de5bBmI~ld14UxG3<^88!c|@Z$kHA2VbK%h6SB` z$hQIR9@!FGWv0MjA$+xBUxpZ`yga zL+$`(Yi%w#Nn7hkUb>WYK)9;%vcq{tNMUk?6C5M_lXxPcUF4Vj*!Ax zaQ6G z2Hx{Z4=b6!bJfa#fnRA1SyL!QyyTtaT=EetmOo+VzXt1l_MzxVU8IS^4&CG zHtlpkfC!>$NlQdNa{x5Q3OYC7GQSBDD9J<5l4~v0!5{d!jhe2WCi-?O^`frKReWxEGNkH)55eKt}KEw8=~d~ko9nx90a#P zB;6u94%Jk;3pe{yVj+P7K&dl~c)l3@sA?rSM@ukrx+pk^n9CFBk$-QHTy*OJy!kS*jAwL*vmwiG3)|cI zZf%?ulEge8=;{bR>{E)1lAn2<`=dJKeU}}XyvQBrnaHzUz6j?|-+X!0>Cj+kbei!^ zu{|Ts$EyTS;$7{uK|*ct(OS>Ox?s`}s4@?c_K@(8PrfT+t*+W17vJKG8FOa_m%H4njC?HIMEdT63FDJlE zOUlCiA4;vlScs4%4*O`esV7owj}TS2s+j! z4vWnE6`pY=-r80qd$QZpC9m|)(~ey$ z1G|Io-8L05r`f0(lC$)1lxPMjT?OPvz9S%5Z+i&NX3}QZU;?>NI{zG;ttI9dNO^K9 zFKE|QpOT1Uq%~I1g-<))osFa#T!?KhQH}}rZG1c3q7x4xh7J-Ms%1&J9@$bCX|$uX zx*-V=2#Fi!3f#3>c6ubz)ux8PyRmS3;Q3`~&h7IGT)T^s|0>+W1vYXbAn8eNZJsv~ zo1D3`-)-2eNS=K~1$qbEOS=VOal4-QG(-e?HjW6>=6W6&Lp`U98KC6?|gG~LER0>tWLMZP0PB-t zOTVvnSP)yGjX{&l&G68kTUjhvDk3g_klMS+uXdGx3>zls&c+Os%)Q|!%dM_& zneXzCxsU*BcEfQZg0CWJ+6GH@tVN#Ugar=G-45GeS~K8wX&+!~kn&HE3y%1U2wZ@C ztAeAn^S`kzMq1!S3Zh6ym_1?PQ^n)2k(ITJ3oqC*MUcO4sW2tNP9D#%<1meBBO|t1 zUcl{MzjAy1?llMJe_GfT3&9g$SzXQ^Y6NEp6dvXwd%7w}m+eB3?g$bM>SgK={X*kUQOGZwN5gl2&7oWAm>yK2RRK^uO!NNJ?@L zW~9V`S^M(a|DJKBPI`4PbVg_B%mC{Yb?a7t*QPt`^9$hBs7xsG%HQ06Cf0gJQ)22q zPhfz**Fw)Dqdr5^U%H@L7}9tid1ULdNRW|ov|U|)=FA+&P0m5Ds`gfSUdiM! zl!{X)gjZ(qe%i?I|75qWfMfRD%UiLTEOdS8ksr5f z&D)g+<{lLQ=x1{`bIp>v$GM}=y?bhwP(@P^Qw-d-q*xK)c@lCaknPYFQv*kDa^>vf zOA9v�uC_G`RdTE5!-QJZo!VaTUebY~{51yV*e-kdcDbMQ_Y}PXm_%ExloQ_?!! zdY8FSFA1J4hpue^C*vc3ZSa1$5%L4jA{K0}UudiIV<>u0Z72A?gMU8b@moUPlKqP= z<%x4l;GBQK;+(v7@*C#*mr8><)m(3*Pd?{|4jwBW?Yvfj5*GQfhOn3y9zHiJ!v*dUIz$`3rN=bz2tVVYxtj-16S`vl8w=iz>twy{b#`=JA7t#&5R6lCJY4&uz9;s6?r4@<@0qXL z_`~h?!QD&i?1X(ebARm!&F|iFrFl`SSv><2t`BuORm5qD^;)s@k9?RflR=j%LyFYi z;gu1iVZ27iqw{ynH9K>i*BHFM4VwPcKC=6If8E`up0c;2_^(MJ{;iI0Vm0p`>X~xL zRYb|w%$_n=Li}|Qa;CyYB;w7>Tx|xzP+HA|Nv()*;!6ePq2tCf(a0ohgoKk*>6u7DGlKthz@#IttMta6+nHKMHF zHv?)k9Jdqkm7}ZM5C;XzO_$$u%fqtYql%4eISGZ6^kIz1gH?Qi)(k zXKu$F1MQzl4?bdcO6>i1*djSJ9W~mbKu}tUr(!+E*jL4H<*7tWOIK@Z8=@zXZGE}r_c>rWc;IZ zvc9t;L#t-CElIwf@2s^`hJtj1@c2=O0M5txRS}tEJFh2i>Lo^e8Q2y3Me%H4q%3#i zIq6ERBOE52KKHb9wODX14m)Qi{QBg{!cpN|{<-e@a>KBp??EVag0xwCXrFVQsmA8r zM;HGgePP5w3q|zUk{3N&R_(tUxuFtR>L7(L0dps${^2`!s_>=W7F0}_od3&YvF5w` zM1^2^Gf04XQ?~_R1c{v7`U)}z#;t385QUGgYP|Tvyoo@<2{jZaA}#@D4svO)>uhra z<)@U_!l5Q6Y_6DE0(iDV}m1mAj^bxK*2nW2IGDk^hqtCKhRNxX6c~hV;q{+Sp;wr z5e70c&|*o%l^)({qv}cZ@!sp^Xz;#S@g(Miapi_0C?A4Zf;=6@xMPO!xQq1KFU3rb z4xA>oEI^IhnL25WU+NApkLo!hQ5T8no2l*w3QWa#y8E4q#NnVFUpcc>IQc}J-iYZ` zKt!_`BcO!L6r+A8`aQ~yk#xUJv|MQ2e@{Y$#J5%7UjC$*LvHbHSCQ?fMCe@NU>YRr z;9oL&w#xExe|oCMpr*tv)2gA-DBruV{E`-@EQ-|CL$qcE#oLn$7f>!oegvlg;|oAw z$_uN|U#&-B^Eg?Tn(nMa^Xc?XOs!(}VTHDtnvs$jsRqA%N``2Z6u(~7P0?g=lNUO* zY9uf}(GFhd#FUR%Gf9GwosK^E0~{7lXw`uWn~&+fWtk-3*FosoZ9@nNZhNKpTJPLEsKi zpzS6BQ4o2qI}6^Si4)hVXwoQ!XwTG%xN;%vs(ahWq~lry=EjZ{SlZj}gz3QvS^3W` z1(>JaMw|peq+z{C%V3qeALDNmLc~}c;0$_RJ#{e4QU@x{5^(jPE5O%W1AX$4$<9j6 z<#E4Zg?fZsvoKNG z5<}tST;9-^a>AxJ7_(9iZKA3YO|%PuSz99AKxcZjaQVBRr|;H**E&}&IcSl#rvhsPGwMv{EE-K-Rp;C(rR z6#Q-S=ZUVLe(1IpZl(F}ST_1LfZlq&^k}R&liDZL#nWl`v z_R}HqhlP-1JnF)Y#;;y}hy?KJ81L%9B_Xiz7pQiMhapUx=ld|KaCrV} z9@3h_y{2gr#`<7+_nh{SWM`yi*^w#sEc}D)#!idxJ68&H>eq%!m?P8V=kC+K9!{CD zatke*R#oVF`6wi8OI;|(F8}Xm{Aag>KJu~!7UZ(NX#r07Cz{fUsMHY`!6nbQBA1VC zA7+l#rXJXf__42$(<&}JNipbazr3sT2;36+=3=f^Pgfw0QC$F*JCi$bk5k2EkhB0M zVQ44=(=8LQ9LO%s+h}osP!(h`8h>;4-&47RMcJGhq>s_Slp)CEs0tL+5~Fq&BRl|o z$I?9{Pfei!)PJ{>`(t2@Ro;qJk15ze>_agVLJv0+t?N z*4WOtPfw|Ipf+25Gllj;7QQG>bmAmP!l1wC4kL7MEV5$b!)#4UiDusFk} z#>ynu0Ix%*3V6fDg8!5W(Xvx5Q7TC$ceIGmI2nspAiPj)<-7z5MJF;}b;ywz42rKN z=mjR7x~pGbJXr*9hDf~H^umDbea~kZ2L=ZH=I_rz?K(iBiZRu*3xE(EZw!g7fGLKjFfcl9FMO^FFmV4h)565b z;aQ)zT5zjQK+iLl_hmR;?+xbU=JP!xAibnsX(8bhDNiRkjHBe}cbpwl^`?Ih@oqlYhB z5Yc7xq?IGnHC-emQLQqYDAkJX@B8a_0_9vfXqkVi^gi^9T zlhxJG?eV)au5d&(5=w~@vCbQS5#=an#`c!V4cnk6S8zBDz!n(F$Q=e(FAu)Ibx2FN z^JX{6nWd(4EXr+k+52mf|J?4WBbW}<&ed#etjc3i96CYh)IHH}bJ@LDcm9A(Su0_y z9f}<@0ktc6ZzO~nqD`YD>e&x%=mR zIJ9}-?c`}>_)IE{v#$1^C;;B5pk{6Vs)h3%h@m@VOqvxXiZR_WEd$b5W(vYUY-GlBK|>CO@h=AFW0l#5hIl zkA3v|vjjg;gdYg0qIJUFo$8+B(7XLAYi$*$nGmQkUh}}VO|_%>0D?g8M>)rB`HL?E z@kMG=)cVY~=6&OUi{t5(g1|IzvpH}LpOe;A&zb=G38()pb+Tx#rzq(#rI96t7!)ce z;gXpGT9R#bek!uqG68WYLS|K~eJ1jK<$_JOiX^@eFHLNLa7{oyQ|V%)#x8+fVvnhK zLMIdXzHPa!Dh_Okw{rxzu_9_4#XmmS`_KbdM3x~eJaq{r@ze1}rN`xF9weW6N=1FS zistQHw!HrEOd~`0M+Nc(Pft_zqGb<~Onng0uV>wsP_OCRD~oou=rt!NOFnabIWj22 zrlP$2mY=Bd9A5Gy!Gd*CHzNGUY*9e zHqrRHPt%)*pOr#XYc|PoZ@Jn2h8DBhno1;Z=kam8X*2*#s5lc~R#Cpc2AK-3Y3$2B z^aT>X=-Qv0esIG;|I1;UDvD!O=GCJZ#2g09BC1saZAY1THsUOQ<#DeZ@4>JRm_~Z(Q0>^y; zy9aQIy8fjPwB|JaU`Y`b+v!(~uP=7ml)igFpn8#~@!MiW>oi%jq1U%M-&diE?5GId zddOd03(2aYWx%nGV9RxG+^wP&85^c5cM1WML5mL7FM9)QHB{#kLh}K>ZKp-^?pBXZ z*aYqP{$0?g2alGi{F=RwZnZNu>sH61fL^?@3l}e)s;BCEs70VDZzmNy%Vbsmv;;S$ zR8s)g$DKQeQXmEDISN&Fx{6ZWeth6kxzdUzP?=8_(Iz2#GxdF2HEXv~%2dth<5s5O zdRybdOz7tb)fOB4<`z!ZmwL(W-hH|ogCd;Y27O2{F-d)quTc(w^w7L2| zKX4Z+|4?Jg(3Fy@%@BoXD@J&dU`kcl5tCGX3DoyaI}-2umIvwS^Sa&S`)ejndyxNy!@jx@Y8s~EgY zn4EJ;Sk@PLyM)iAQ%{VRo?uKBQ+wVNAFVAv-dV#y{SFu{4b4L{223|Yb-VOt2U#|3 zxvi@LI+0|0AyZ{BUVM^L+uc@tH@DdScn!?8dPu9?y6TuAmVUxhbut0}7^^`haL%O5 zn9}Yes)|@`$C0Gpq3}FLvJrM5DbSdcWDdL8%z|C zTAWYtsN>)&zc#=-;HFkuIa1VSimFZ2>GT#HRaDts1fsv=!h)ovWt5+e0k+cWN*>*~ zwN&;H=NRsfBrLwG1gKaY)OS=reAK&`igc>hE!-6lkuSm^`ITLl4J94T=Q~Bk|0pc= z4#rO8L1s%=?H)lIod zlLR)-2KqGkj*WoGKCq|l)Ul6$BoN^-ZJ8 zz=wx*{%x!jk>zPI|sIKG24gJJ@47&i&Q?JfVi~@6wO(*8ny?j^+~O!eYMPLl|yrD?=om8 zp>2`^8dpCZ9xa&*!66q92i96$x9vcyT^wZ%Ws@#Ix3}7fbBbbg>Z(Jn8`v^O6L7jD zZ=YjEE35;i!_&XMHliGWDSmOmw^qkNdCX~*UFSn}P5r=>>U67U-u+rGAG zR@#R*CQ4HK2jHT%dWwfFHDR>Ve(Q#Pzn`^^JPmMu2tDgW?aO{rAcrS7#Y*~aS|x&L)yr^CrxmDnLt zLO#oSj&Z3J)F(NHgJTBTze^L2Rh#X7Xa9brfA>ZEhmq7SDfD>eu9HuWB_{9G>4SR= zHc#YGi~Y|) z5D-KjJ$hh~R_h*%2gfe1s{M1I?F=i<88(Rjy=ZUoZP7D`{ZCut25VqTea_JL1Vxms zxXAWDwrkShb^qL>9E*35T;+$IDa-agkyi|(r;pc^e0svlU)v9-Rje(2IQ7k)EOm3m zR2DcmpMj742RRNywX=~~W59e8i2nsHZNQO$o1pswtS|VbJF{)@!ev(N(TthjoY@&w zPq!~ON2}AGdC}WC?_|#yxp)C~kd@RxJS~BBX&|uAJEiu(eG#zULIYyRPQPsN6VBD` z1h*iIKd~z}$0ZgCQ^x8M3Fi{562KInQ??3F;!jj-ZU1+17l(45rpJ?1hyJ|e`)>>G z!Qc*0t}JN8P4tX%fS?rZgHHUH5;rF8;%d~3t#obY2q@WF0S3~9r=3*gk3DlcD z94$=UscYXKRabm-%e&)kYMUg>etPZ0f)exbb3;rOi*J9xF$T&x$=G5?z2z$^JRbV@ zG^3AY=$JD_?3O$Ozr3F@C@X*5;m`ly&|zr*f5W9nE>=DFn_KcI_B^eN*n1D9kKuMCn;oou-Jx1`6zQ+c=V51%x%3vRA@dsM;VJo7)? zDEu^)P@>zk*kTgbr_tAT<5?+vVj6>z<@iqCOOfwda5Me-XZ*6h$$lFB9-}7Webl84 zo`cfPJf1?+gF%`3VaGRbr}y$6COFPl3k}fD=JzxsJm@uUtVs?H4!hP;=&`t{=*w+0 zT8l8H-m>B|tJ31z?dL@T3JTFCfQ_NbvA>5WQn@SD6oqPrC^Y+Dbk5}`hi4Sqnb&R| z+E=ChbF(AG5fB`?YL&H$MybkrUT^tn+JB)OSsKS&{N1q+UXpC(W=EEa>rBQ$#)V?# zT9u!EKU4qPf^R?=x5pzd$aXl)yt}?}1H^6qDM>>z!^@y0^&9=|Kgz+Bk;oR|41$Ab z-4^@puQXTd`a7#>9FdApl%6|X^z-zh@act?T-;n=mU;tq%$(eE8Mp3K?@~}y(;;s= z=eSHouqzy|`-96KjB*bD{YweWB zwbnP*?zYypzw>*&`u{Yqx#vFjeSNO$eZ4#hzU8+B+%^|+?ZBh>LQg^!Mm#aAi(c!G z>mmc0xIJlI%%LNq6NlceTj6oecKCM3$6NjW+lFlYwAg%X2XR{CdB^7a{o>c@&we?; zZXz3+dlgkKgHYMOlUE&{NY}CcLI?v~pc}T)$5fnIu62yAi|sriXVt|6h|fpoN_2xX zh(Uy|(OIJ@CyXM2l%h@#HzkPhai97o(gKy1`upUEtG(6N4hmhkTqO8|tqABUkMd9s8gPy>><0=`5{5L_yg@Bg z(Lk%0DN=HoHHr=fU`jlsAu7=8l*n{>AAE_S+07-?S&VN%ns$j#y)%|jSHr0YBuP{cZGh4GfdP%tw}FAeCZ8G=nQ zk15FYImtlS-|YJa>*!%nQBrs4CQFD2t6MPJS8q_M8K@mB*#%_0hoIZ4edo!G_|<*aWoid=jkIg z$#(BuGV1c0T8OV2{a+kMTXegHZ*i9FOS_(v-YIv~0*v(vJ0oF`-N+|~v~`FgGnPT! zij}X>!yfPQq3<_YA?Vq=yd%|KG2uB_R?4Uilf;`lF-KNYk5Aj7s@*>|P8uW#8QLSs zx+6Mb>sR~Rsjbhz z(CO5^hxqACG7fS@O!R2&kD^hVa2|RfQarQCZi8nk99>$~6B3WJ9S%Jc08(+0IT%2D zs<-}lRcU(RS2b?_AT4zMyKOrm>FfFON$=uSbQxxu1#Mp3AzK+2c z?qKN?vD|s3qS+{e$;l!naH+dy@G^gWge#1MP_!+&P<4($986gUKkzD^VM9IQn~a<} zyg7l1Q8Ui5EwT5T0~uO+Cjg=YwY(skkq@$G&$J}-J1vHw_Y5bi1rn7ItQ}VayAavO zl4#`hcYr(e%v)x6aaqcy$*N3{I1>UPZvfh17iAP9!tqG`L5N~7_pwZY zi`Bq2JxO=h(1V`JR$l&j0Kyrx#a9IdF@QZ2^N-qLu}Qbj^c|DrvsEIBwxrC>Q#&bH zq2*4ZPl$a&mS%Z>?Msv!2#oHjz$9~}k>olaI{wELOf?KL8=4YsNU9A?kNi->p2{Gf zPY>#;t$%##xO3LD2C>j0`sST@5r4PKKU-IVE#5h(bkJ}k^Qj!o{6RED1DppfVBS}N zKbL}5dsE=nL<;b&)i9k_XUd|SXlDcGLGLtR1)RPUOk`?Y`rE2~3%FaQt+h1o<0F(i z8(30|(a5B~7eyRav}4ps-fx9bT6xH zNTcZyhg*ac)6hNFvS8biXR|~WuT8Df41%KA$&gZwomClr7qWi;aO4w^1HT*apUbJZ zsg2!4e@HdLL1POWP5Sasw({qY#UzSgFCT0;JJ5bCUU3!pb!T%ByV0>qSx*koIBRRg z&8+*8UR_%gxM;Nw5z>I4 zBRKP04Zs4;)y6A^=f%wQ`zpWB<+wjt5 z;;k;)<9h;Bp>)U5+fXm~>_a_|cd&j~SYTmmu)~>&(qTMJrUzgBeNU?xO*Wg7e7>`- zLJ#?d@+o^hX%1$zz$m5d%QyZ)i>2mrsp}N+cku?mcd8b>{eE@qOrr^N4Rr~+EOD)x^r$6s7NJ;m969|}uK zzNFAiCJT}wtrNK3qm&qDO33jbTEn!Q z>^XjeKP|fY4kRp&-OS-qa!*J4ui{z^l05X4OSG6~-eI zLALl+t=lC{pj<6Y5KbiRy?@yyec4Lwm~A1XqS=!a z3In97{Yk%DhydD4xJB6TwlWY zG?pIA?7tQn>H5#IHgLUbdGM4&F?0CK67Y(J0x1c{= z|CZLAN6@-`F;6$0=QSnfIBtLr#{YWVrGw)5+lTO@3@WS>Wh+>#uU5X)yut3XVC}|> zEa<_R#Xtmcjl?|`(MK)OQJXB^Dhdoei~nv zke_*K(@|7PBcn083RjRr|Ic9mP*#~bh&@%Y&?zCOgWW1=NlV55Ca`Yb@5B3pO9&8_Pf=GT3CD1KB-u{$rPMJw`@s+a z3M4sd7xT}X>i~vwh#CreOojfz9E%5zDh_Y!@#2bI?8Qc0Z&)8QK~UpCWCDIBqj_+= zZZwAXnBJTOHcAbp?+J1oM^9gwxdPp@V$;LZ1e} z+_?ktDR#IhXkvC~Clbs!wWuV8r@1ZmzAE`Tga7s#k9)AK1@%f{_-H)e3JB6%pCe5o z_*rdg=gT)F0g}r}F*mO!?aVa&F_SlLmnXi>;)4&lv)GUvCo8DzUzKxf9<}ujz29`Q za;z6mx|X4e+bX|t(={FCj`R>F!M)`<+@!v^d+N9$}* zNcc)>)7x(6#>9nMZ@2|~07D+#Hd1IL#~?{Ho~0dr@yGoh#Z}3H^1;Xb=R8bnWZhSi zD77+qU3L$z4Isja=Vt<5`$0_04)mGrzkIOi{h^x{28ivjbtT+HyL0L7bJD>BhmXb= z0MfdTsG_kdn_CN$d9`X{`q8wRQLm{G$xgT8VZf7`hCh!n*ER)(^wrc}J3#_)@=P?`>f@K2z zUs9aB3`Z%%W$G7B1sn}%&>jq*&a5A3!SoBq3#Q)0X~yRwgP6cOmZ#swV%s|1^BB}j zBOL&)VVvrJPB1f2U(bqvuYFk3o2aKD6~$?Jpdzqd)-Jb4UBH0P?@cBoinvTJ79I7xX=j9{TQr`@q1R zhb^%twg!w}=a5zmT7^8m9(cBZ_I=KL`*&9@{48M%syqB$Us932Of3ie_NALD%07;v zGtl`EwTps8z`_R=AdSQ&!&mmUzvR#h1X|cJhZ$TG!y%c?Eq=J7OkI&Oh|=k8w4kn2 z(PwkNJg)YQh?*W>DyxAUt{!dsEv0igQy$DfNdf!_0HF-?G8Uk4=%94IGxLvYG$Q=* z;B<;dfsB^l?lY%YD5k>$$^qTGN)8$eu(8!Ko8(|)JMa11o@^<9>vJKd)WLONW5-?D z0JM1O^0umiK9t$=Z{yh3T1~C`VKU{B2mJvImrMdI!WKZgacbyeJ1LI=o3E?+Q`L!o z^O_H8>Ag)h@5@l_-1y^zS2;<5PZNdObf=u3FU9oLxci<8&72M;qz&A|JLuKf%;8Li zKnE`6C69}LoU7&RSyq!#A&s=&&gApDS;gRJ-eS8a!@(q zLFH=qauqXPj2%2AV?t_^?sD+?tn+x{x`%L_gTA8`dba8K5{NEJ6 zN$qB5;28#_@nEtm!y~+%2JTT2zy3!=C$%R|QU8+vM{Dw6^E$_)374*J|HW@&#Qx<) zZMc@jb30w^?BVGSFLnf((Q%wn;b2yr@Za6znbQ0G35%h42Fk~FJ`hC~EqNJFi`y88 zq|+lz?(tM6soXMr0F|arU9!$k)O)8yoGOoTeG7%kjHH#}*U(ms5)V-~f>(1NXN{Pg zf+y2+ZuQ*9nS&UW>{u560Ke$jJ=1A{I1s>c;QA)EHIU}Et>8z~r_8&yXs_td9Pg8q z=nfxlvB~Vv*Y2xul0gCO9-qLo-@s}Z88CW(e!KI*vC@0Ho~P1d%haBS&hqkZ`VK&A zF8~QkpREz%^6xlSnru-M1BOx5=8&YG&V~6eKlyZM|D|XXo(TkRj}oe3K?}umy|Ov~ z(PkFdItKndG;Zi!L;X25aeqn5Z*H_8AM4?}o|^R>H)n9_HODT;MG6e|A z&3F|6b+bdZ*-7zr0BUlpmA#EWn zlXrSNQ2QugA2|$-Jw;@;i8Z!2$%%rN`ODdFzN<{l_~(4N|3n_e>dM~FqnsY(m@XRx zrkg#f)$Q(pUz7Ft`iqBw%Z@c}{QlYXtOWY9 z*&B`*4x4-otQNU@6|sb6X32$yZB>jgt5fkZ#5ao?CkASnO|<_sz7AEp3|+;>!1h62 zwepMTB>IwuDSqaSI6!5bKFvJy`r`!mD)n#!<@}rTr0Mj3?_<8GYTYb*u_yH2P$kBx zCt{v8EImTGB-zeoMNR||{>x9ZeP3*EPsPaZHz>Gb)OLI)xT#C{c!R*gUmuB9-BO#^%wc_Ng;o`!S4tZJl{-%l{Jk;vEsK!g@xqMFV@_ z2l1^ttKYZ+m}RIdlbtR@XvBEDdvZX>CCAqr*y1i|qAF|K(k~YO?RzWJl-%9@KKR46 zTJ6;*e@W|YHOUZBW8k1=&OvIYv8M^iW#5e@?-=O+N+bCq+@U@ChUuf_Y!&PFyY>HO zsd4z{p`>o(Guh(6>A05W5YvEyc(f+%jFNVc`@%_sD7TUSVD&v2En8Q&WoR+{prWz@ z!w2{Omq^XXxO=OWuFIopJ`(%7+8W+Y_{miNG<%QtidVj-JW%u7@vtSU@?$)|?$@F1P`Kb4P#&t@0X5U!(8ZLGM1;ks^(+)$J9m1Jy64zbt`Zw}}7}*2X znGeK8K=7Dq7_WFLEct$xkx0Kj5eMBU!VwB(Ee&KT3}CG?+o#XL+pEL%%cgu>Jbz9WSR=nI}o_UDAo-)a_ zrxJCPgm>8QbPC<3e3og}L4o_cyyj+A*xNhBea`t7QH{3bM<2qzfd1qaEVBGsYYb$Ih^$TY?EjYmkK&Z-NXi;mCl3O;@?V76-Ob?q>`E;Qn3%h-6L#g zoR(OdM5IPjit+E5qsm-pZh~QYYcWzLbEBy^nM22$M42~90~0?m9#p3idyKrmx_!n^ zdV@$KPerm3;Ild30en^+C`DW=EFcKVqsq9^l{%Um(!thObkEY!Q0j9E<_*aJ!V7#p z49gMsE>z`#A|Xog3*}QTLuxkfVh9f5L%Y%|hw{ z193bxtQIU$q|qdYF&J%{R=R9X(99WFNK+lc_bn-2PNSK2pMCT1=uguEKi0R3Mr7Kn zq#^MZ^peSHF^R>b9YT~$t0lZAkff`-1&jdD2jR8or$8sNtlh-F4yL#dBO;SP_Nq$MkcPSpo` zw+Kux!KJ9})7FH@ZuE_m&J0mEEBcJflh1y&nHffSdSWYnSg$$pMX1+jbXoV)K<^`Y z849&1GDA0vI;KANF7pao>k~R~yIycwnmWi-5|IiR6_Ecin~8_vU?w@fmj&>7Tb}~C6rM)nk@=ZQV7ur;iv`MqM+Ne%2D3^mbo&$ zn)A|R7Mm<lfOc@gx$HQnn>Q)1bZ*3oOwISa7xBbg&L~8qyul*Phb-a@7yxF>i2vi?q(3?hG;- zwrzIf$#lxHeg>OsH^k@a_RHc8Cw-U}w2#4Y#T^dcJhe&158#<=Pvkg*vqRC4OC2(E zmem-jD}xHRkDyh~=RJhkTt|ZH1)BIzN*GYX-nF%MVYGz=A{`gJaZUv7H1Ckqh5!K2 zaCW&3fnrxZjFFbfl067Vsi37o!1c7)oPsD<@9qXwzuZ?T|5iw%N(gGNcqCfx72dQ(JH#$nl0<&@V30i`Q*l{Ro^>06dY zY`pwu1BID$KM`aqzjPth*p)N`e$Mb(eS@+G@>kK zsv`Bsix;u#?|!BJ)zFhBx>ZCq2%q31chMPDew=xIo#}V&rPIbpY;$I3VbUhTx%I$& zPRFYU{77^7Nl;mM%QYu|m&e6+hJ&y!Z@)oj(W+|ODlTNueYqAmo-#T_Ql(Xymc8NjPSvP#*Q`8f zPhr&_*^<-?QFncz)8&F=bKi#w*Z!J#p%4gY0{uCL!qs(Ai(@kmolMqN&Ad$!fy!MM z7n8aLZuCB5CULYX7cpvO;vl8*VaufmeSlzP{{4w3(FV-F;w~~2NtAJEXwp~nb3fZ$ zSM8n4K4Nc_I0ouIZX>oi!rE=W*Zus%E5r#H_pP#j^%XZ;xorE3Z1)F7l{}7UnT!w_ zI->(ITR@GM@-?rj>1OH4wOdlu{BORvC@HZ_8p>8XNo}0a%TVI#c;p}2jC-}H9h1$T z+|)so1iQ+Eo>zHQt0E_u7;N1yNxZsO_zFd*$kx#1ekty2cAqZxcD*SmTJ?~x`wPyk z>s38omB68Q!P)U$P>KALXnU#WIF}Rp*pB3Zx;^$AAa8v=9apwRzLoKbUqfcwmkP0i=7UYKc@qBEW# zQX&VfTL9V98BWep<`Z5cOY42zwOxE~BA`GSndhg_Gj)kN`qLbpat1M;ehwU&D%9c! z)RO^xfoZK0F#+?6t7J?0<6%qe9LI{oV6FuQSw=}XOT>n)G@&reb0ySHDmwbp6b3A? z{B1r7>8Uq6E#Pch8jG{5BEhB8kb4YzGyp0-{U1!5w0-OgmAkEymCmw&9ZLNTGOtDo zDOOHUKV&jfNhpcoj54Cfv;*JfaEIi)g0(bNf9D0~tPWqWTHHZHU+6>^?ELmbf!0Fj zulgj`?{4be9^F&OI+A;~i4ii(q-4MiyU@gnmNv-^3z|pf6ES!QWK&s^RdzUNA@_a@ z%(o*h%nw!~315Tw=@drB;rKXznt6GU`Fa4~f-aeICW*#^N|6QjW#J?o3B03y{llMy zJba#0ac1xTr`y+_526K6Da=HPU2U&Cu2iH=Uf?1^kSi})eiSOLT}C*Vt` zunbCMX0|g|I9dvGx3_!RMa^pW%QiRm$Q#Cd@^Qld_>ss09rn)V_14=h$eZY-!a!5%smJ7pdvzKe%O#t@DSoI_%;9hRQl zo2cW|0M6%~f0kg}PyodW!8$wM`4X_%1j>ROJ%R8otfi;QkAmFEj2k(4m2L*xg5juB z0#JpPZ7VW)6VW$N2D+CkTEc=N?a-ypWe#8^MQ|${Qb^evzb)PqbPIk0*6IrLj zwcjD3FAzP+tFiOn58?yA@Xy*gC6pjvCzy4GrH>%dJ@-8$n^S*7CejN{Wz-vVK0h)O zkjqQbRDp&Dzbio>1WjpswwwQ0@;?~+Ja%3=H*k|YJzSomvU>-vFHB_i6ip$n#jd!G zP79&o8!KYX9565+S)%fX_uMxz_Y&>tjY_>4=7bsrA>OhOh^T~5XIkJJX*q#9m&uH# z;<3{zu!nqh+U}DV%I7qaTs{3J6MS120v*-4_ft7bI?s*p*DrT0PGzu4dgcf5nDsTQ z7mFT%r->&-L!_v4+&)nNxxcp$zqbhO2Sfv;=P`y;!VHfukey*f_es&dJHa&NXeS+C zS_C>BqNPzSgEi9d5)g7y#}`vr5f{hz@PpKjEIx8f=N?(v_01#n^AQot<&sgZ@3Up> zIKra#ddy&*rHKn@B)>aWe-{*;`XCG7&;3RF#Mnc7O1fp%(k9X!V6j{*lM(~<+_)Q? zZ+CBgPNE8hpsVYLro!Ckx01^%L_eCKMEmYk3)UAaukl@h{(>YKVHH57kN+xaS~7Pa zdB{#YV5(bB*12QhSenL%1U_`ALd^>y*)J6|IcSyl(!?j>auP`;x#O;+y>lh}!IMuz z=^kTMQ}kr+<|&6mbcj3wDv6n)APFwncc)FR)II+ELdfrrVDriOD(ZHqGx~G|5L-I= z%%As(d!H|#JKhGFrm}o!#Ki<$mLs3!O^zWydx&B#F!cN+x|HOj0AjDc#Uv-6LYN)B}NPGLc{eSO`ad(34c;l1c@sb);;G>%_mWIihJVZNaXXU z_M`niia-W0r?~NrNhSYN7)Rm*EHP_GzQtP>_|f3#(?>z%KGn+8iL}zyLTkj+inPny z*`VBu$@w{u%pTfOz=`OrH&Jm{WAlm9@%yd==naafH1?T;;Zzc7rKZkhpdcagboqG; zJezpQdZ%LczH}>%EU@Nb(IodixjIDW?7af&jDts?TF5T6M*dlwg>PL^nE%YdU=cMj zoTIe%nnfe9xsDg{?~&yt3q*%bJ*64JO`0BN_-%|M+xMLpETSoHbcgT(>6n`B`2H5( z^^K^KN?U4SQMl3<7Nma~d#ywh=l}WYJW&sEm*hY#0mMTJTJjufYVfKa^JvVL!3K}x z-p=}_Z9)0)N^~>Jtj#ic|D>yYIZ0%q#QmfwJ|H4ex7O#=p*u7j*fG;(54Meq#4`)( z3%(jXx%={XqdY0nTQTsrA27g<_Zz6eXH#dFSKuB=x}kI@OLH%>Gb2cpEEfqNk)&A~;(VI4rK4 z2CXcUbXdeeI+&ffZ!7!2Iypm0X%v;6n(x;p_CWnPuyz$MKPuK;H!IANWE>DDz>ANy zOs-rAG5c9vVxuF4_UkZMcSK4akyu3csJ}Z#GQ&n8A}3-2gVpHpk+tj6Nj4$G@fqRl zP1a}UiJM44Gs&f&HOZNz4~Tv8B_Z}a;Qm8@o0-aJUh)Q7A#g|B3I zL&nZO^*@(LIpu8&s{S`wb*I1SLO@ky2yKk{Z6w%yCSUV?h~`9!P6`TyKhPe25NF#QdibuPhAeDwhIP6@i! z>)X)_tB|mjKms7b{ZarKCA$4qVjxorS~*D+wvQBLSlB(t;lm2fev)rl%O9ac{TAUU zaK5Z|3sQ@os>;pkh5nhMl^Rj(SE9#^@cq;m{gmjHMbys-(kZcv`tzRiyd}DL@~-G9 zGvyxj#fDzaN;%H$#E#P5Bt?YjDowuQ7t8?8V(;_@k%gU0sR@(X@jXdMyoIq9$i9nm zjShZAUMPA76;Arv^#W29)xOpIx}POk3a|x$TOyv``~BuXn%8RQEqTKo-&OcEEckl) zjIdArc4laopM+>?Ux?G12IKN$Q{cB}{R>Yp@WAA9-^_4IOhn@G+3q2hWvM#kY z7U2{$a!iAQg%*T8bT|H`S)6FuG2g;_sKcZNw6$=hYGtAdVAjFa!x+5gV#P-cN`OsX zH2QI*`2M)=CZP*x*5PszVgqQOY=j2CL}Ms)%^8mCwnCq^yd_fw)pl{UeNyy^`caab zM^#ULnf~ff$WtSBdDPnz=E%RlyKUj7pRlmh>-WV^U?%g6QXOXNv zY_{{aYy8vymdAo^kGf@zYsUiMl7E}Oi*7#3$I65L+PNHlpo?f(%bru7R2lE5dJJm| zw(dXPOb)kx>zwnyr@==X%V#r59;HFQE&n<33$&fY2=db)9VD#{m<9HHvob%!9(Rt< zSrDY?dy8Cd-*hkWNeSR~+Qt18cZlRFjy)Np+0!JyLN02*B^;xO<~8oHdg3X&aC|`L zO7=Mn%yhtq@xLtX|5_a0F%BgW;-{eWxP8xE+R@b@?6dGLoDgL(!Xd_?8N_|0FM}PXtsq=ib@|ZId zPCQC2D69{3o;3Rn5b8!`e zdxii{aB$;QXaV`f2#0(M8eUJ@Q!NRsNrEKN3H7t4HS2$HG~xRKLL>Bj2fE{McST5D z__S3wj_{pLbvo~EdcM$AoTvY#(9ZyY#F}`rv%8G%XXO8XPt;KtU;UV!8IG#t1!#`8 znBg*j3-eZ<*==Kz9G#A%mxIpfL#$DHeK^K#4iTTk@vox1wxrt=$4G>eWY zOSrsG!>&|+GOfU5`T0ikb&u)toAsv`^tpaA?;~K_(#45CsG?dNtEOfI-9m^$wl!;F zLAz~G6E4g(5a)lhn`Y)vqp|MplpV%B@@O0GpKpZ#2es_-T;Ha?;<45~j61Erd`rzy z}|1X`CNTzBuVDqt!SPZB~QG7A1`y#|Z;=I+j$gn2@P7VS7C^Y6rEH zv_YR`5H{z-RCSPiadflMDWK*+xqI;dbkzX@=3mj(E$m&x>%m|L-}PB3nV+Lbi2ri2E-hqQL~`1K86 zL4iUaDe_<$9I*4J<}&v+?*S?fQ2fz#Eo_M9OjsivS9NjJ++j#m^vZ5<@@!M0N4jEiD&=S z+5pCdD!h*UqIKraO{HiUsmM^{ho5b8zhBmn6O=BiO7Dezt%MW{sxf1Pquop|>sf$~>~h~& zGhFUfD4bR`m9d_`;#N3r-d|!em#aqA;vx$|j@ozSjX^uNlZhrp4P0x%IIAqU@Xtmd zRIoa?0`LPZhHsT5I`3$2pv%;XS#35^DbwZt$^tvrm3rn`1U{6+%PsL%MzCq)}lK?rk zo0W6f-ljQcAb%yZk81|b-8T4A0^@dRxcnZ5xIvVSM`*51ts%_w{c$mkzK!1W18B>d=fSq=g1JBxIHF!&9r;d8@agh-S?jFzf{jH z%QdE=SpsDM5*Iy8bPIC#?3sYN*E6DP2S=>R|{6bH7PxHz}4L4uv#BGgMmcZmmAn} z$w|A%SzT5FdCv^cxw}laUJ+9OTP+8Y&x%3RgJ?C0h?`J#^DK@za)3OmQWF*L?>l6DAIsNps$A=ax6$#3l00R+ond;62JxLk~WsPxssT- z<$s9GYE*VK@PwM{1tZna{_?z2MCS4O=zo^a$ST50G`+c85}$`MLVqs3a0jR2zu&UO z{hC-+DhZdg1F*%?EfX4cO&qnbqt3-CXGal5j<-Oq+lOu(Co%b3o!BQ>y*kK#@WgqR zPLKW6>}gFqzPz@dQ&FQ=95&Wk?JZvmqTFr(ROb8!*%i1i_i7?IIuMO0*v)?>oIae? z6BJh)Ah>VE*QEQ1XBAJyc2>LaEoVbY)2_i|p$43LL%;2htJ*|R^Drhi&{@aB#xe|Z z`V7s^LqN8iKqj-(!aQ>n5lBK`XmzUQBh+$cgzj#8qUwPn&5O(%U4r&wlL=(fJ)&&Z zQ43b}hv5ewW{5umxVas5uD3~ZXi_%6;CZlj@sGP^C?Xwqfb#M3^9~!)6%HmnVt0N$ zCuYo>OY7Ec4t_l+jDX!A5vGYF)^U7ZN5dS^?RDO5G)FBj6F{S*S+&#`TLsYR z;+Jfahr}DcvEr#VE>;Jhub7Tc8rAQb9poe4IauoSc+fP7%ubf&G5GC1nyNy)cInlE z>V_&xZk)1oV4B%%#bO=S4~amPR*@;yb`fMjj7;j+WgIMxPS!y=iPj4aRvBfM zL=LBd(-(}LjSCJJ2E>)R79Db(;14xujV=^#u^x*uV_6{8Y@W|TI_Hc+Y+2W0We7G1 zI}SqQD4Flz1pOKc0!CLYVxSP~gb#+pHa%i@mI)ZG4*+OS4m4NE;W$Amb*^X>_kqJ@VvoQ)bJ)r{uFON@ZK zkyZFz&_WH*V}qlHDVV3>BwOJqZ|$bd)8o=$j%GSPef>UT{B4uZX5Wk#VQha2s{d6n zM``#Pcw6koW`kq`38|)k6uK$m{P_qr0p?G0+4b+Q!H?&{GR?F!0u38nzqm^>Dj{>p zIcWs8!^$bKaCeV#e(Lhdr|@pey#GcWy6kgiZuPoF%D?)5C~!6 zG&9^%mL8_(Y#=#nbXe?p+@MJ+1T3u{n$ZaHXkkAsyiJJvbliebj#nXE&cL0lxHl^5 z)QNA{y;B;XHDysBkADdXmD0HY?S39t(+$GIcF3#{L9CDwGaFZHBlYmqL$K5W&ow7F zS~67MPm(wJa=75ki8v<$xVlbAYj8Q(DNQlVdS{g7O_k0VoZOfSd1{O=Y64jqIMfXF z875avT@`m?5s#e?k;qWiy-!$;(=ooq(H|hOY;RRL?m$T3}}7bQ(!>3=395#|)A=G`7g@ zx~DYE(;`v47MAMa#L_+S=e;Kr+(UL;F^DTP4XQgI*JU*(TGv$3KV38gLRBU2mx61oJa|(k72Sd6~_*3(s0u4 zaLMZl9MZj~jxCzFeCYsOV9rrzvWuxfn=vnvmd)1`!_FgK&PJK36|!O^hP+V&SIWTu z6C*9C&=I=rjjmr}TcbIe%Y-`1(>~i2x30lOI_SpzeP@JBR{veq$*7IvFAhyt#if-v zacQgM!`rBfot$k!BboV8OlZ^rc!OO^ZGgBMcgpIXsCEwj9ByXV7sl*bGoEXPu^fok z=%1^ba6fL|4B7exZZz z+;SanWRGz7plF2Pt{FT!^Nsf9v!fG9Si5q(Q=|bqC6TR1mV~c7Hqi`+h3!ahzze(c zq#BlW1yAC9;Sm89d=}7n(q**)E;+Hz&+0BSn@y>o=akIwwcv`n#847hX`8G)aNI+W zw+G5LmQOC~bv>NE`lDQ^B{g}15{Ux5ZyiIjy&*R{;xSt0X+{&+fvy>TM>)9vtKeFo z!HKxb zq$qEh=$h{9k+ol|E(wRJiarhRkL|qUVc#w$u~~!oPJQc}X|vXn+Dgrx9?z^g3=X%tA{tn% z|8UlUg$6r*=@y)<4|B1jlo^@L&AU16+w2CXml@FY74G30*ojOFB;(eYVKK=$WQ9e8 z@bpMtzmE4``R934KdQ=wf>>fWlWC2Sj-SVWY?xEEN?7GD`Z}Ez<2m z4jo=%-#T}dbR)@KrIE_BJ^z_h+-2S}GAc42U-s(QTs1&PYPrUL*GE#7w^H!!%B9h2 z(fA&Qzn9_m($`k5KBU)Nt1@#cbo{Bax%Mzl;V80gCb!Uv2a>Txqj)JH$N@8)ETBXu zzDot1fSZqP>wB1XR2tB7lS> zqmcFXdgrE9=U#Vw;sj2GfBY;VaGKub(i`M#4C1MWuDpgkEyn6Yu!?{loQI7J@EwF_ zSd31LNvpMJd}%mtWT24QUxoIxM6uG^b#gIAN> zEwqD;yoiezFK$97A-MDq$ju}JSPjM?(37=XW{u^{CJ5}?? zthxCF*I;Govz$CDXQyuT$*{-h@A&Bf&fm1dO2$KNuPr3FeM-@8Gq+QDWR)3@B=KCb z23d7uCnY+q*TsHw{@eV=y+GvmDK_;^)`2!W-HLDFaWc&KI)ISO_IMrVw3&1JC44r3 z6J>Pxi4e{?BBb@)GID4&$?@A4x<-j5r9?Vue1J{Yi6ideNv>RjfyQvxpdkk8DR+`B z$$!qWNz0^j=QdHVp|UbbfmRcKi?JEuUAj;+QHQ&XLT1Ywg;XAimE=(-R5x$=bweB#iTY&X+2R?cS##jn2EwVZ5?G zz@OAJ4L#B%iv{JNqk)FTL{Q!3tf_nc?Xd-Qc9U#-B4|Ms@#GnPU?1D^tkppB35gW5 ze@+N_I#}BLR$bvy?&`yJDknOPV-Mc1?r&fN=VeY6Nm1C;4f*PQ&a!yqK|VfYG;6mgo?L9GaH0PYW}Fq$c%pH&+!kuUhI6n=lgt~ z=l!(tH*QaG>@Nv1*>@JxzyNM`bHY{ODCN(noGLyRj*h3i?2ibT*Pe3q)2Ry=7y2!^ zU;*I$w<)$%Yu7|d(ud79XYsCtu3bsNc|X;x)!c4rx~?ekFntxc##YP2TmXca0j6*2 z3XExA*L#fTlbr}MKYf5U9YK~@j{z*iSV?}asK=iO=jd}G%^IoN$W=kpUp_&>)_e;o{*Xpd8WNJ^f#O|0V^$6WtozI^7jfkewsY(d(< z(7sGjse}qi9HF=Y`)Nn4P;D%z3Z{*uM4gXzQ02?>65w|#bT-D#wo`H>xpWXrR8QW# z&(ttK8o|2~Zk!*ge%76`!AgT_8`kYQ=GZ1XqhH}GKQbb(OeN}}>d1LWxDP&xW*U+P zrZ-i^gdfz_1PtuD$vD@PJ1Yaq`R8W(-5K*ozJ}kdI6iUm8}C(F+3$WcfG-;P2WA_= z2}=nmkaqDZfy{#|aE6i=Nr{5MlOvfEA{#mJZ!lvty(<^g zjaq}A;uRC)8-11zXr&kMib=y7ZTJSwwk2L`gR>lx?!!5S$%lO#G5%~L#Cz0l_r*E{ zQ~e;{)hkJVK*LO$M=~6vyH8ZPpy<&m!;sxN#q}E_4zu~#QJJV7STvDZ5q#IpOK%_K z7a7PoIo@$vS2XFg0E9k3W3Wt5YLT~2GGL5f^shXYk5l>*YD}Yz zf{(m+KVNyYQfnyK{L2NK;I*)Kf7< z(3N)pQ-5DK58t&Dy27Wgp4rQ?+d$=TtbcnBS51K~T^F?F+RnMLbf}Ws{!3hx^TDE$ zIzrfT3R)3yU~$vRI-iDLCm0`am+arNx2Cm7u@KTlwme9YZ zfnuVYjC^}b7drg`8A2GDd1kGDaCKBr>tJLs0qj*OL0+B_oF7ZvdvcWYR|At@CIQGg zi!E-9l(r3s{ z>+4q?$udGUOYto}nP2-!2=*RY&2^GB8RG^>A~L`ezr`ySHdsS#cy>yziT`0wMA-e& zu5g{EdXb?+a--4_BVEIxZAAfx@hh>G4Hd$=RwgDn#yz+E#bOI9iBPiJH zw#&}gt%}@mo`&l5dp^p^i~NJDcKP@|F+b!sc!X)E@fpXBOw@RdOXFai3d?6Iigr!w zZ6RE;QdNiX*4B533ci8pm{?zdStIEMDlLWr)QU5tx3kz;Xp`SMWrpm*sAGOGJf(Ud zGHu)xTOfx0x)DeCobw}T^=JCWN7?CPrtN8BIJlf}B+wtVlt#SO=}?XVwdRpdkrl{P zz5KI!R>Qb<`Sn9}D5ROauz8!`k0>jmr4iXI?-W>d+1My%-<|~v)}onmh~&*e;9x6N z7yQD)&db&MiqtxGz8lieB{o~lonegMSSC>idZT|)>-Ah2)>WgEju4r__->Y>5Jrvp zKOy~8#3Tv}smHm24{RZPrY>bM#o&(4Wl_Vx+twJJaewBNgkNLlAE(E~c!1B{tU2$T z$KR+dtIrUeQI^N`8QCBG0}8^OyE%mf zyuqSnZpO%DT?BKzo<#Dz)@XX)JsYRrBJlUnb^w;w(@8UfcrL-2~WeUt<1iLeU#edMjiTe`97P5gVi_&L-h0r$gQ`x?}6u z3fa#$ucmdIDyH?A1ZG{fdMu&F$g`njPe#~tH!?*H_r@sChJZMxA)6}C9KC<8*vpzH zwr7gzcX;-@W45j46UA=E#nke7PUaeFoOO?D4b|J?E&eBUMA879&_X0D8D;h5`A&AW znE63_`I6C7%f;oX1CBf|W8R2;F2Z>84LnE8n5?b~2CW%>P#_8iDs0ooV58Vk)?cxH zM79O1DA7;&xTRuiJNrGH{a&o{y)oLV!rrK0zc9qg@&^a7&I zqGVe}&#oV1m5AA! ziWyROUCMa-1)3~Yr6!Ng6*A3O?WmJkdFihYB*fGvQGeyWLf6VF9tDT9}2m7lcJNL&d296kGcLC=QyQV$YD<+H}?NZW447 zv%HW_Y@Ic7KHh!tw7)g1hm?%aTszW7K_D*Rr_X`t9}7&qE_ZN^Mh|WQyANa>#b?T*pmr z*bugTXq5HHk6q;aB(9wdtG5RCSY>*f+)0K}pij@IGB?>bA096psQvrFi9mu>;LVr; z4Ai6kI>N91&}$<(Hrnm7e|D(ut3S2D%^RlAue0GuMTwYY(mRZmtOLIQYo67-T;i&2 z!uTUn>H#1J?1^s12H;fD*rbP5tm08&GHNr8u*=%hij4ICvtonMrZw9vMB7J|s;&L( z5(F!w53+cB;^>(`BPkT`!e++!8M6J9WWq?r@1Nh;idKBo!%i5eu$aM*N7!3k>@EEj z@0YUj+HKG3ZDTzg)qY{18url zqY_78qn@E&laUptDlj|B0QSZ)_Duh`GM-Mp&zWM^R|{anVn)Nz;`< z8)q_?`8$@}=`zyvRM!;ILCiYbMatPL%!bo1eLVtcgPI@JmDJ(V>Bio*#BO=e)PX8%1DfDEBr@mG9o z*VKe08S?o#_2oPD#4+O7v6*1)yH^FH!09pI5Kz6dDR_TxRqKZ1Lgm?msRtnAXVEDU z?Jn*GPDtmtcAI^Hgdq%iGBVXfC8jM0WSE{ z=6K5EBkeLLC{ZE|1Nat~6ab*_7!uGz3fkcyPJ>_s@X_B00PHFyaDybi_2VOBM92U* zr5}{&9VZ^v7OKO3Oz<80V0+Q1<1xBDNk8%4@7DyQ;KUX(vQsCYy8HP)_#bFqp$?wj zVvX)UAFqQG5cr&No1b!0QNpCZ>u>tIr=u~Iev+L}cX5^mSd&^-c0v?XASqW%np)RtU&KdrLuS>*S>KfA6AaZE=Y?8Kl>wn1^q<4w}rj7-tl^9&j$KY8n9_AS@@VI(*Ud6ZQ2iD zOvhN~Vl7k{v$2X|S64Oxrz+-*@u68xTMjyIpXKx_N@uBOfj#r#K(&3H(%yR!US+;4 z>iJ<`4s71&SciX+?xx3(-RKhk^_AVMg9}!it3g6*EdUd*5+YY4wEO@0;4{EL+?z|* zX(YbwJ6H5p?L-b%Ani8?W^NrO&H>y`8dVyyje%F6_Huks-3fA6yM$g*aG4zSm!mtI zF8~qiuAY{Jv{mi`bzpbMEZrC>C+u-!;DJu^ataix0|ld%@wm2{Goa{XP+n;O-ctC={lB~ zfD;KasmFSHjeh?L!Vf!s;KYuk93n+w1;qUC>Mz|IX!%Q1vo|I-u>Rqd4;}Mc>@(XL zA{^mTm1hi9owuEwQa(aMYTLgKR0C1=h-843gtAvC>FpZ(8ROa{^Qgp+Y01|p`Y}K<0jI0HI6~<9|U2&s~ zssOZ~&T6;Ek8Y=?*{;vFovhpej84ulI8t6#7!iBpBX(Io8_z0VRm%DZVVheVC7NX!U~i==@+UDgl?k&-S!QQ*(g-Y0E1&BMoG)e;w>ajinX8P9 z6~wQrlx(wW?@*x~uV^@7t!xZ#yI*jsa{R!AYwfKY{67_kFzR=mj;TH| zaIcMKJSu=g&J`(vK2x)IugdbrzF%=^@eCijwQ>ZMjS^CT+@VE&n%`o*4*1s>a_FP1 zhiD!pcxE6_J%e7^3oPN-=!rmj38}Zyr@C+!nPakvAMr(q->v$oyiy>gQhZW$@rJec zCnW5Lv9AgZzf}GjI)d0|71MFI92y<^WFMUFRaO6zh3i;_7KWN26V(uno5tLNQ|&;B zZ-~F^{^8MTp&C}Ull^$@HTR|a&je8YfiXfisl_(S#h7QXMLcQ&8-bUv6(9ScFhB%) zE6MYNpYyq;Sk;RN9h9 zZ6<8T;_%JV1px)<-^KqqKJFQLrb=HyyqARq6=3Nf8*ZiPzMh>}@e$@I=vkirERudT zxbmni6JcmAgR{J>RmvslhUh$unWrviE_EMCFE5!{KBfNlIwxzonUQ)2UZuCKJ8ysS zOBh}ck53WX7p_=+CBw09-wzpBg=!S0$8Y8U@NoLR{qeRkJ?nNPBgJUX5Hm#H_;e-t zZW@g1BZu|>6AlneGjN4%mK~{_`_X5&OZOd#Ik>xuP>5xdbweEeL1I>7=55mX&rniR z$g;=#qZ^MMN^Re1t~y=68)+wyQOLcbA0gj*j>Ec8K|kt4vm ze*u`fo2c|`o5BE!-hg!-_xZ_mzM3n#kx7ZBQd>< z!5NC~k*2JE@dhls<~&n)#7h-7u{gW2+^*JkV@Hvd0Yf~o6VUVtSJ#MB=n%f1?l}G{^nM|&-zh7==0h7Aj0ZHAtGS#Y%RFb8-Mlq!1D{=E593lRVhE(y0mT9g|m#}-K-h=5@Qs=X?=rwAZJxFlYp&Uq6LX|Sn-hCi--$$grwp%2Y1QxU2`9aCZ^$>cB@9XT0&*nm zlRlD{y}E&R7pnUP*@;6vrkPV751%G|Ld{MsqNp@FcD2Q4FcRzZ+ZELXk(9nM(-fu^}QOn3soZ zm~ZUU1s#049gp*VXA)Ozg8cj!kt!TlmW$kqWbG_PE4K?m*CBo4#+3c z;$bGf_tmfGR_7}-kI&B^?F5c3K|y|C*(iuiz35|Q$SefJdUxevEnPphlWL0w=4$i8 zN?j0_U+IDnX(6EW3sXB;=z2Bnk_}L~z@z@n6kRkgGKv^8`H!vgl?}HHvXySsAvW0? z&-9vIr>};YU(&`4Vzlq9JpHeD#7|pj%HXD%8y&#}a;q(PBceqx2X?3?A@Y1b1D&nj z*};vg*ORjfuZr$`+5bW`la(g15UiOJ=PB@$FqDqPHELxR^SPXPTY->tzB+@M>mHF{yYaosZr)|PNwsg)&8^Zl_m7SZ|xyr^r23g*9 zZa;q7?q6+!guc;So|4;sALn#^hVC{`aOtgy->B}?y!-d_pOz&By*Rrg?cd?u7d8?N zJIl*VZlIH^Xu-71sP@Ln1j%^DjZGgKp46)ToaKLAb*tIL!t&SoBC;F`N_JI<$B+a*w*;YUy^D5YBeWvZ6`Uu$*F$<{zwM9Iecxp)YyTAV?C{S2=oodIk zH-~)R8_V!};;nGdN3t2Wp`Bxnqz1Bb{+q)euk_lIhXMYq`k`^5&BzEFRaP#-ox#<( zUA28DD=T+JutmoRJ<=%9xC!RmdA5-F4HZ~VBRlsSttk1gw)hr}^pO`%ya3i_ws@H6 z0KlF@XAN#zH&j5R*D;gkaCJfO%O)APNg{Z+@>&gw|{b6a!!8 zSU&fW&H#MG0M|H5&pMn5nz@kPTfwZuz);BYl!bH2vB3_IwX@$E3&h^I|UI)cVZ>zTQEwZYsbYI!59{t zUG+O7+=6p5(Q_A-hdx*Vbf@hHeNE3B`PaPcm9I9OFA_JEBv0++3LBAza}$c=TPTjs z7{sQ9rZyw`wCKE!mBTBL_dIERMu+krKbE@LRQ+t|kaSpW zmyPwa*5+${H@5di4e1Z3>YIcQ1{^Vcjr`SKGGw8y3wxsEt^mT3+f=92jF74d0U+Kvf^oABM`A}FpEOW$>x*Uq zMJ@{9qU5{Lz`_7$pnEka>DAS+b;1X!L3u$?Yof&t9F?>c?t;cD<< zB8j?2`$F?ULrf*kh_;Xxfs6N^lR^_QlDo-R7v3JkoK_A6ws~5jTUuy{NSw@NWE_t1 zB!s&1A19ac#d=adjUdamV$hQmkg+wf?tCc0-zz4>CoizG=q-d2fh9(qCBz#TktoRF z&V;x)&fw_{7$CPa@H;pprl+2zDoOdR&kSs5_ts8-VeK5+(s*{V$MtWoplSApk_-m2t ztL8YTO*ws(8K-~?a{3(jOOxaqoF#vGdjGS1qrxjxLn2|~;60xR5-LHTE{i_MN%A5L znvx$;-}Q_P$t`Dh($gvYf9TbdW9-K#%YR^9EoDpMEM4I*0b9mJl39q-l;)R|Ly0!w z-DAFvbMIgJ`>hpX!BboEgU0QphNh2x9Pz3Z9jWzTO|b+Q>McbV8Y-opSs&nE`ii*CzX=XWuUMH zK%@cA07SQAdKt0Tx&%g#AfIGDr`JM+po}`41x)7XQ-0gU&ZIV64CV5! zvUH)BCg(v!F0&QqdT(D$0Y<8*M&)j#5&)aYkV=Zva1hQV81@5jCIKjMHcmafcwOHH zCf4EI9FMGx?AaGE=L=ju8Av`Sg`1_l#oLe?JjL`Ak_Qyy4L9MSeQM)NqlWKB6)vSIJq&P}g#4C&1vUb+e!bUBGC>4~)HD7^AC(F& zJ9iv^Fc)ujhr%=h&Qil-5Fx{lo`^bcDqu_%bQ9=^QrezX!D*v#%81SAgcrT2NE?9{ z?zO2(ZHTT@RSzXbLh`x{S~o2_tP#Hlb2 zDjLp6ssV#VFeb8HSK6+tS0=m!F@?2ntmyV7T>b?hy9lb3HgSVT>;^a9-I09vZPi6h zi1EiKZV%$#BN-+mYqa7X+goGMThAdW1xQsL;qj0p5=vdund^&CR^F!? z0ZodOTGMbCdjsy_X|nR_vWhgJcQtqWfvqF#bhRzrY~!|R_A*H{*CU}ZZ^YZ&mw>RD z!t&_+2N>d6^ht|{?iC5!G#J#Z4K{d!#whxmA4=i$MaL%-=P;rMZxMoikD#8CsXoIH z&!d3MU^l>2Q_vb-+=h7%67a5_^BCeQJ);W@&MTN-)>j;j+-VIRwkifl)Kd2|n#(3A zwhH88+=v@ZfP$J05YRwT>le6BX@**nk=FBRFoS}WvraGe%wBv@)pP0=)zC3_w4?(8&{;ukU0L`T< z_Ro+7+I2K95{D{*<_2nhynxwCah}(~R1CLe=EaYQQ!Sb!3xfBA5i)TH(t3}}=nNUD zTl@t+=Wsr(BCZ{4U^O&Qrify7IMhmA)8LB3i#THdTK9-c)m(6ZsP6N{=ZUYP4L&va zy;}Us=)#mLfSL=6@9tT)FL#z-)?o$}!=zyMw|e%(xa{@d#h z`y@SC-8aZ5mK#qLo#pXxnz<;YZcPfZ+{^qlr4ClL;r`yvDMmC6f$5FZiR+6I*Z8*~ z4@n>OlHcOro^oc!g-?O-yS7wIr)pEiEA!hol}E2Dp4`bkaXt6Mb*G1mT0^D*yh>cG zvn8k;OS_oCZql$nWh^RU!THP@rj|`y5Nh2MT71r=K zJ&8R({)w$j;AJGl+sa*@ba7|mOERU!Hp%6YJtrk7y!jQFmTyx$w@%aBVp=+<_t;V| zlBstyQikF5m$noI{AjTiYav65CFz+NTZ20!w3F;NkLV2!F+ktEu7FIBJhSuT-)yRj zHmTAhX(Jc>=3R2o>iS(1WzZqn=YJ^_K7Y*P?)zMg8X9z%~&!Y}U-GFITNdZ71S#x>iQ=`idD+pMG zPuw~GkF6aOx~%cJTBet9pT`lL3U2i*It_5Wc(<|uQc!?0%YA~DI0+voap%~gDPe1J zUHIvl@A~8dlpDX`(8w-&P%nXuJ$~K?M(R5P5L}#x19`NEgq**i%eNW7kcqP`#C)1% zSVy9EI(1Dk2k_3x5lM4#u83HmS+QVp9&;ovvx^G;#Ou15`>5a#6w2#np6- ze04)vv+IoEmk6FzcEA7gsg1Ka)SigFbZfvpex~(Xjq(a7e(rGXJOwbDM|G8Ona!cc zuwDbc7tEaVwfN3(pT_RnE8ueP@TxiX<|mmpZxNd2dVcN5niNAtO?i&oNz3|hYL{(V zb7yj4M(SPb^g0qAdAh2->T(*_XVb%J`EaVUGm(+;^rc1#SeF&X-%IYkdpWCU*_=gD z+R3zC9*5No$I|xlfmU2r)A=$>_KVu~^S?VXA4LGIm8#Jo`c0SipUzm7LGzx&+#RuN zY|5A8npiQsG0?VY;?X(JzWqb_^YT1_`X!n45|m+(9LZ4O8M5ayibgY1GW$nvYh}*P z>Gkl&;7%o={V(}cFyi4Etm(NKSX++-v8Kirjg^-!;<*Us0K~cpBWgUWE!b&_GIl!o3nU~R^vh4nVcuXDm3mn*sINWfK5kyao4ZA-IZzw>1F>_IFgZOu!5VY zw}Dm&?T9kqt~!cd9$kDL75UD)F!G5{CEnQG;S$CW{Q@;#g>P}VWkG)6Y4}TKBXk`?!v$KbV zEV=>JK`p19JeRIuZ=DHhUFZ+Hz4<;IFL+B#(R#hUyrhXY1VSp#- zZ8ZBgLTQ_`Tda0{IOpfK>pEBA)0cx!I;ZICQ`|cebudu|%PpPLT=3+Eqv^&!y|4EM z=W3*gW`F@k<$`G#{T(DRUK5A@xESZsZ2CSR{uf6cWzFBH2}QL243EfwUk;t|)?#4c zH4p9EWVeiKP!;&!5I7-@j3CJ49h7hrea!&H`R4==Cq@@`{Ex<;sc(FfLIJ=(63iBZ z&}9~sc+HJh_Xf@(%ngCox<9R-}NS%YzehPoZNwWR{j-H~4#0kW4hp zEh7Z@Ho+fHsbQazp7gig57#xeXR4 z;FPC>LPqEYQSnU9OE=w7tPlvg+E- z(DxTEJMRz8tB`6ZVDquVyMh&w4ez&~F5f+?&hJT00wD2_d=-I{OhbW@WOZ~It?0l-g}s(H+JY9} zrY)LD&`di&fKWv6rwBPV!MOU7Tgw#8b7t(%faZ1|D47RPB#5eym5UVKxy&G=8Ds@5 zQD|;6l68lw=9tFX$`>rwNBGDW_6-T8E9=Q^*pv35YWYbwAdHG0YwY5RY0AIoxmI&O zKyTF42rXz+xw2PBGPyL^lfmvCgKGs~nUWNqnViA!%cX;*{(5xqWb`y@#l66MypfrF z20nTqC7%Lv`Km;3)c5gXZDq<6y_^x$s{3=kaBpPea^a@~gA8%BQeGKXsMF9yh%;WB zXv>I+6{Q|%^UdhGbyfCpUGZP}>~mSTJ;2z%hUvc$JKk6>jgJKnbx{y6!nxNR0Vp5n z#v=fItT77I|K38Vg9h+t{%U+9%%cn9VPBdfTbpFnWb7|#GyNkeKCS1Z2!sbm#JI9OmZYdIH!7bhspys(d;lrwiqOpoH4oArFi>%e{%M!6 zg@WQn2c^rS!dXm^oj$6$)}&`oo~8}Te$AkzQ~)s(AZj7@K`6k8#L>?OvQx%Pf^~*N zFyD(ar8t>1kAQsYX+NdK+CnG_Uc z*~tjkD<|O_SW^c)g7K4XTb#utNAy74tTs5X&I@VlA$AmDljpm5Ytgx!dCJVM(CnCK z3bfgkc7f{~Mz~nRZLWg--4r824Ea8>SQp&7^8QW^EO|ml?#wuBNhN&h{A%Pz7c}Yj zBk1M;dt@m9y~rVjiC#EtJRfvnw-^#1g=5ij0W;t%t@K0HDYS`Y3U5 zc_PRiXlD2kyXpL;RW)Yv%{4;kI8ZrY+<|&%tG#uuLn(qJ@uM~gZ~Rv6y2y;vKcE%r zQTg_gX`46vZ-q4NW#iX<>t1cx_a{63h$iTg_wmNwp(61dC4XzfVd+!dH`90~&RYGU zFCGrCrdYn=uU0~S7eehZM;dEa<*xX$MrRKbYFf0HypTU`T5Iy&Zw)T;;Jk4Y!zXar>5DZ>VTB z3f1=3Wu2U}bkd{H=c<-JmUO%r2gMGBT{*g)SER_07Ms82JRhPt$&VQmEFK7v`&x*f zPTM=L%IZ9VYa2Jq>i1kspLQVq+N`azx1Lp`zxLa5t0L^>8}s#~rn!7iN+{6Ge8rHi~ATV6838A?oy_SLpVSiSa+C z{P`wrm;97VBXN)bDHorKl(-tZhcSMh3vk9)RUdPuQ=A=g~Yy>|p$fCTcc9 z(CAExyBFXdCm7N}Kg#QDtB#gZ*Q1GV6xbZ!CjQBCXn!QwTh0Ey59GR_As#~oz~OjA zDh0$h2FGjYQUJbm4*!Ued(y?OE1&~e{FOwt9(btgy$v*H+Q$LL#6oytE%Wz|7J90V zahR`;@y00B3A)2`zJ3Lut0PXuYv{RM&&_5zh;r92!ggGEkr2MX1ZRkaX!gRVSi>X+ zdy>&2q@Xr$P?BGhKWmK zE}qWCMCi2#DR6;#mv7T0QtNx@@n*W)+i2-&ys(gzA&r7@{55@f20XlBm!D9>6MLj> zbo|vLeL2YM7>su|zI8=TEcB5bF$sDGWle(%8n1 zGaFdvQ{;tRSH5k+p}M96vAzD~z=Q0S`DI3y#Tztp++7}TL{R{3Dk4aaeO|h2yh8E%P!qvn5uzD-!PL1~CFQ3<0li z7e41Y7C}HTNe} zI2eDKL@&$z#sK;oA$_!}K>qBb4$V9BZlkT{cM_DS3{f?Oka8CPBYT z(B~4E4+)TwEq2XH8;QQc#7kP>#KUqs;TxMnZ*cHC$Dm0@a7(VG;<5jPb(!-hVJ;q` zY($iD|Hwg(7yxLf1J5|~<%XJveYsrM^dDNUcNc=q-9K}QUq`FF)pm z+wM*B!fm7Xj}$k|!_fISv&{>i>xbIs`zT_3F5-*m+|M+Cy%}St;m}GPitUBmwe>Z> z9Il>;6cl5bMgXF&Zrw1~hH)1%^Pm<`1b{%~t$Y`Z0dtZ&=6;NE=xyaB<9OSzTNWQ< z5Cq(XA!z`6HFY}`mUWQ)u&&G;3caX7lFe{X6NFOi8bCG2yRyr8=OOjaLxQo;D|=*; z8eAolt}Cjw=2#rb1{jyUbcD2+@;*r$8X1ZlneQZoZ|SL0Zf=s z7*-NqwLo)pdu@!Swklw8&`bU|IP!Fi_u%m2Gv%VnFER$ifzRJ#d^C?)mv!gD=!dh6 zLi-+JJo(;)LD4vlUj9Njr%@R1S^CZ(?Ng8JoV@&xFNB6qglY4Z$0UjhNm1(H(o25^ zJ|uIVc!hVVPaYmZ(QZDf3E|BsTG}lsP)Qo(>--0S`$on9{^r6NflCgMCb3l`-UXiN zl}qAvVh8xYh7^f?5N(o6n&e`UNrg0`pUx0F8+n&C;!h_8H#9QGY}GjcpQanvk?$6t zO)gGKt#3>|wSWgH6T>?f_Zcte@OO=Oc?GZWE&mYfIG#Bl)hx>`+%wO`n`Y1~X?Y<2 zQ#%ibB5`iMv2Zg5>Zsj;8vi$q7n(gxo)^5Wt9^L(`kR)SVuF3WFqcKJ|N8@|nFQs9 z&w1(ErMbWBlVaae;Ioz%_GWsvI-+Qo4^;3U=wPZ~98exJ?rvP)=lu91FG zdSE`xiZt##d6&gPxF~r9k+m1KV z*v`!qXOr>mawFSyeO&aC@O>`IvQTvVot^4}6gY!03A(5yj-~?+A+iz!8CY#J0cUzS zZh#)?hjhT|ZxzAcKI0<ZSHp?Mvn2yQxZKY@7}53j-e!@=TkQ{=_=D77_$ zfdk4ft9Nc+^ei{%hhFBVjlp#!XCg?Vl}r2_KGp~Ceb>I@twy0pB&a@GYVBBgu2?=| z-Eq{{`CYw;a4n9%-d472!m|Qjqo;PlpfFx@N}AXe1Lps2Sbp`-Wu5nyX5HVl^v$k+ z4c$x2I@o(dng)e$C}9(LCiz|KDEaOl5VQ~eu`n>$Ws0luNMG-YzYL0BWAQtsTkLXi z>Uc!(1rkN&eq6cVkdglo_l3xNb&axa5A%5@?aTc7=pY&o`k`cZ*d(M3M_#=U?gfv& z@+7ZK`t09}f^JOOwd74dfiC}_V1P)@sYGeAx-RC{6=L5%6pCWd3+P~Cmhestlzoo7D;$F?7`ws9K8}1xxY)=#ztEw^E2F%c0(o_XVor1F zO$x0uU4FIhe{^YC-qp{F^LV|k^p-Cl9hluPehhN_#*U{V$5~#Wtk2fBeL~hK@5w1% zlaady^8yLq#pBQe?7KC;`cGKPd#jL{79exTCBfE>=h`kC{OTWJex(Y2u7;TmSS4I2 z)(GppkJ5)ThF$zuz}X&6lgYp>rQkFpkc#p1@nrXcuUR8YPTyPmAH~mUVP?9xmEQSN zBKXiCiVULRy-$BVY3FK|u6v?j{i#vf60qj7J={^bSeq`^OalB@*>h1!Pq{`u#q=Zo>36D8gMEZy_g-@7gayqf3ur=NM$V3T{?Oh0W7I@5DrMPkzZ{bT)#jf}OXG4DMinzASt00WlV_7W63l zy@JATQ2-kFENx|%NM$>6Q;dH^1PF~y)+!K=&2J$(**+q_3HdrbNmls;MQx=_gSpI_ zxkrGRiN^&#+PHx>WZ2QK%GDF<^zB3U&q5fLCq0$WX>-;0qqs2R2jCT@N^^bhP;Prpn z8b4NlzxV2nLcAQ6Z*%aR`+bD*&2Bw(n68d$s+ziN&O0d^2A5PkJHnGF?>m5Neu)`c zXxG%jhQBqxvjxtei7qK=WP~IXuo@Keqd7YTYenbHo^W1G*1Cry`-w zA;mq`f6q7P<|h_)j=R8>Rj7Q*`@=oO5S(~hYRx7KYUa93RD36v^bG*lBzndg*}O=A zz{P0bK!4dcI;2m&uF`Gg#6AD%?dsUqXK$vxE;EOSvqh%N=%ir3YoHhhJPe-FyqYktXTE*gNgq@$v^V;^*)y!PEuP%SywSAxDCPwdH z8;!Ai;%ZR`C#p41%ZX@Se0AT7BUMkL>gbX!pZI1%+Y}?cVCKfu>t{xl=3B2euL7(N z>3c}(FHzS0i>r>VJ}x;H7hw?(~RJl(GDtK9wewu!H7;nn}X?7XQQf6iUw zZXb_mP45JR{FAXwJxyt^N6!k{CZ6eOt^O%eh^RU!?+nS>yo{X5M-3&Mp!xRs2-b`2 z>`VLV#hu)X*OeJ>UHWDizK(R%C)Fgatl4RhnQKT)YpyhTNyb^)qD$2z)X3oRd@q#~ zR5uFhVj`&xw3pRQWkuvG0o_+0t+6s?>Ma_iGIL{g;;M^-))3`1FA!q7xS>LzKmZDW z6blGAJ0Mh)mnp;m?uYergM73i8d~%5n5im}diZ|Tg}#ATOP8vV-#U3QNJ8?un0NF` zq4;{G72!}HCA_mrap)wcMA5lmA>lT5hkg7;(dZaSD_U4@TH0u;4E8f9ZNsvbomSq;Z}gy9{J~~~6OFHMNG7?Q8h^Y@PllIO z-s<42ultR~Q9L}I9VXv|TWL_LrQ?NpIBET*zbpTkQ8r##)j0)QLpifU*7kVzpsPeI zdD3s^&t6L8n=bqvbt-p0Uq&^UAfHvp>6$yL3@J}IL)ERc1~{pO`#$WadX22~9&pA^ zAR}^;xGGl!!gBiVv$gGzJi!=Cr!zsN8{}smgZU!>yu=9^HVp}*yg<|(kf3$PK#d_o z(Z|TlasQIL+u*ZOItTZuH#ZSvEyTz7acOKGD<|BW=q&5##En zJxwLzCf+@_J!w>3fnl)h6G)r3_5T!|d0Z0t|HlW0TToO~JW{|5ufiKU6hsY=3QY~m zTD%g=3OlW=hC{p)OHIoPkIc#n%gT;zLCmbI(5$SiC9~GHwy@LM*6(kAkH`GY1Lnib zyyx|Ky`Ha!)kS39z9(0g(f_0G^;d9Rv{pBzR1>TrOpQ_$|&T4~2e^mg<>L$sLs)PxOXLvwsJ0Bs`p?0d@@x%~|h z#=ks~7?h5TEL?gZU`BFj-oSaJi^RcvwM1$$2%L7$fT(IZvRBbe^eg2BCv_>^)DUR{ zjpxv7K-vlE<|_+jHuYKrGfxytVj>9roD{4jxhQRDVW>e(yw;pSEfRAbZnGZWknCM` z*MLRT`Jc*2o_)SMDaIdqiH$@>Qn#dIioU#z<-L)&Dyy$unp6 zfC0N3lDm`--VD$ph*C@9dRHnJ5D50#xcybOX#JWDDn)iKW!M%*GYw*80J(RCpVrlDoOv^Dv<9?* z2uT!mt<^ylpTguIAKkR01LL!2%B*4kt0g^Qd+>R5l3l=ZHPU7DEf}r=$Zsx|+(-0U z<%o1)?wyH)yZQD$Ecw2LhLW=ZxkKVn!m<|&dLnRgQYD{Qjuso1 z@MJ*&)LQ#efO(^SZ4vbWba~*WqpnHN#A}({)v8F0Psi8cosGHv>7qS5zB;=Gm$`kqgxpc?2>ED#LmBCN z*E!hO#2b)}fq$czk`V$ixF2KGSvkF+jQ$dwAibg|<(AJm_(^y}xIdFpM}2%^ys$iV zCDQgGxg=b;`Bik1&n7{!_Yqx&xu|{8#YpWxb35Z;0PG|mJ8&BmAB0$?S!jeviu;6$ zX1(!3C~ynyXB+@aH!S(yNhJ+pVYzN{=hw87^QHoaaZc=^Wd0ok;|C4nO2(niBUx(L zG)cUHHXl+GZy88N6|sE;>o|(lYcc2txq}&)Dh>w60v7tiOf(FFV~K}=!`~z`HY~*Z zWGkmEOG|&iS3;0&B;<%}b0k=%$xHmjE2kHQ18fqS1y@2M7tziP$|#8WhlSHQ0A@%o=Sv6{T_vTJT{1%BL5CB3cj*0ym6G zMWg{+4ZtD}zJP__%t3G<*wO`ht1F!#6cPaWWaM#-S$5>kDfC*E-EVag{~ZB~f4L*g z$@@QDTG%GQZwCZz9Ku-@hNNLClNer{9mh-DZ)AA{-}IwrhA3TOt^%3B!Q-P0y^d?$}HFd_HBS%vI(#qs)W0?lM9a zMonyMEfJ@UTA~2lKZU%Wc}$9`TWk(cS17jQXqaIRE=z*Xg*?`1ckNPHz^K!~79SsE z!oM8i3Oz<|YQ)0JWdt%zq;?7sR+Q4aq=c+w7^T-Qf=eF`m1}WtIHWfsB8(D*IdFg6 z!b!Ba3Js1$2K@m@mB*FGR#7vzBgNE~XeQHi9Ic9P8}$i!3IU9ezsc@0}hV<|MmVKoLYGTl|8 zEu)y-Du_N^vKliBn_2qiCbBn}Ow`$2);o>TF+K*7V;7R{fXNWhZm{6Rr_JDGv^r~! zMN$d63;7m9-viKRjgoq4#5+a{r2*xp%CjaTyjm*@jEGw*B0_`AU=`H@ggg-@Rf;Q8 zVORz@WNE$kKxR1^Guw!qm|EGEimgRV z+rjp1;vGYR0rENUxs<#y<%52OtAV)7gQU$VY}Qw=RdMfZaa~)2bO2V~zv1>Y2Ow}( zk^wjFzLkI$b()2oZ6ci2@Gfx(UV@Ak=;XXFV1Z!Yk?&q<6PX^{qgX$FD|-2xO(3!$ z`P7zVyo`k{1Y(Ho$gBIAOrUTOj&Afut)b!SX}A?bGZ?|G+FGkRMymJN+`3jM+UUKu zp0lc+Mc6MPJS3CyX~ebmxg(K&p&oN;C@Iy{Sfjw_YOlpOK+=oM|AQ>z*5Kau`Qrp7 zVl5H@V-5xD%cku%Pa_ig{237O=;x|IGUDw_vULO^rEkUiX6d?yu_eQI1U#Ngzvv z1|-;2+J#vV;UeryD<>QmkruIVPgFP`7U_|QGy;(_=Up5Tk)Co$ZW_coGXzQN>X8j_ zyC@zL$QMxP9ntc-QS2QEvvY<+O_LK9D&i2CRP0DHLYK%$q7EV!8G#%F>wg<*|Ls>w z_CrBvF7M$whDPdr$0>0%V$%)e_Hz}!{|~9ds68`+E<|rhr0yo;p7aq9n6R=ed_N0m z0V8VdHfXh2d{^);4z3(-6TP(9PK#J-Kpg49`lx`@g#@*lm#!h}M7SpiVTOU)^%9c< z5nqX5X|;0QyxJlSHeZ9)N$@u`SU)vR4-q%2FxyO+G7~0)MJ!iU2~&uDfOctciV>vW z;9w-v(sOyl>4S_u6V}wSp@yAsTg7?B#;hZgz@>Jz=fi&mVp3K30kq1#7H_XXIE%K; zFuM_d4GH$WiNvZ@fDu`*MY@;}&MI`809ywaKyk24zzk;Ib~2Cw6P~pI{(y-Nfwp+{a!j`wGT%A+wyZA!3BVvx9Tht*k#ZkdyipV3 z9s_C#fQRVThX8^jX>$Q)_N1&NAauwP*i9z=y|L@URpNjVz-ckl4slbF#12aAv8cE{ z8f(x}Hk7UNi)Ur4aow?pZkjCqr!l|z%K~Eu>-m&T^30&o*=NDl$Cvn)d5rbSX4%pUg)9XaIM z*`x#Kd&7PJP`^jDhZg&SoGcauqg+ua{kjARHxG*m*vJjm0(8?#{4EvkNv|Q;jVb|V ztBBU6(T@G-2sId?(x;g~p$Z)+!2}y@m#NSyiASujw`?Q+<}UnlR@9FR$-ituYdw)9 z$!>;y>4-bZW0+ck+Ivf8IB&fr-s4e`F^s&#y`eqdoyB~@vyQXqdcx98OeBl;YEcZ*)imKod+kOBSF;AWJBd~5*Kk7UoIN<3gY*={lwKZ} zKFQwo4Sr-2`y~`lGLWezJ(G~$a!6Bf)i=vRZUDwKq)F)BWD!Za`~H{#35eX@Usilf zigC5G=a}^7}W9S6zKWU@~Bch;`M2~Y$nb38v za|@J(Srmtijz(I2n+UQIl~F~f2X?JpzFzt~=jhw;b#ZYzn)w5Wc0+}EyqT55nkgAf zZ1KEOu}+Iw4{O~fpmpoDLmpwIt)7U~7fBM#Lo&`+jmd*Pc>Y|r#DMxiPAI&EAsdmN zhLw@_uuO{U6X4ocS^s#3{9j5;M+LA;$+?oiI_d=(RdC%fcw*x+6((){_o1B9)tt^; zFvSG!C4*50aDfpQCqcm-8~vwI!7R{43y*-vXb8mzkbE^LaQVujE%Bl)@gO5vaCuH6 zat>`DtR(sYVD85k(PZpGWA_3Hif{bt3!(lY@hJ63`i!>5M_<$1kIhZo!B&BL&jqYO zp#J&?6-z_EK0W>YpBwos^F}w{8jBx)$SnK<<~XgIgPfnz^trK5MSP_q{t0LXI6o7< z+y=F~O6-Ar6;Y~>K&{yAD?z@C>iYW=1LfB7Q%h`~kG$J~Lk(dEj21hzcJ28!r$ufQ;MSLyjH^X10F4rNZXYsuqyqAbq7Hi}A5`qzx=`h!%Scyb z6dWfxWd8W|U%VWN#^M}9npA80eKV}w$eqGlhl;7TZOgi8(u4M0nYZ=_E)8!VZ5R`9 z>jNMA9u3RvJ7~G0zv;>MUH|r#t!PiEENk}p3GlL34^mb1ie`J)sMxlIAA!-m_m7(6 zEjSPdMB39NAcn{{!1V;?5(&fto@7-=jDwgnIaUSPJo9ABwgUgd@!`-%ir?k3YP$~}*8B1>j9 ziDh_?BWy;gvr5RodF~T2%DjkEpekgagi-FiTLrgxw5eK3V9!TJiSrJick!&E$!22I zPs`1MNVDK%!JeN_jyLY=nfHO6Nn2pvkXd?JHLu{O_PC2gk?7HIn*t+7P~NT&rNC1* zsw%!VgASoY(Q1^nf(%Mg#6*lLzTWbvPwd$*QR}uJ%B!wdeA|V#?4Qm;|D`P(x%_?K_kR~eMX}!bZc@J0 zQ%~EWVH0WporO*R+(9ht)JBuDsLdT;e-MN2O@b;>-3eDK*YUqL?2cG$02EJxdrH3{ z?IEKS6Fh*(VeyJQy0On-cHFE?5g0atTbVye-7AbL@$PAk$rjjIkDz@5CQD;9nLInV z7`xM`B?Z4!?1hmMlR0Q=ui7l7L*_uV#6@(!#4Zv?QXXj)%qS^-$#{k}#ESCn%ph&2 zMLT_vFqu}jUH+XciMT?q?9B6Xxhsm<))`2gD9qikbH*`d#f;#_60cXVpCA=^GT1x4 z`L=};Kn2zoDDz_Pc;T5+^2vvooc9^aa(num>glM}P4t={EUu%XhPY?*aSM_u4OtYr zzd891;_2-wGv&rguMqv*AGqpAolB!nzty`;2um1y3+62BgNRHxgFC35$5DXf=lqzU z+GZlxZ86S+t+%&*vk=8(SRzGo+7W-!Qi&_^cn8M&h?KBgc>dtoLUizLsm1PE&;fw` z@v~dhXmJ~W5yG;X8lX^Jw?B9)zlvLmTXN>ZwxP)TLcWcX2;O|2RMZ?`hpiTK4$nGO$)F%Eqkr53l>`I#v8!g3=9 zcWpKDTm}bpI3K&~C!OfVm;d&QK@DI-q+$`GNQ)zC+?SlmQaKEp+sUbd4He5L)_Q&C z-p;7hZho?lfb4}Te~ZbouSn^C0U-7T0vXtCqe!bA5a={5LVIaB*wwiZVhd*vZWus0 zpdidt;`%#|@|}w}?u*H1Vk`;~>*(ZoDtjF5RIW4I>hny939X;)v|i#*fA(h*xtLVB z4$qw)v)M;Bm*NyNP7^Gi8wG(`r0TWqlqg)a8p2vlq-Co??}5|1{4_3Xq2zQ#z1I8f zsiFip)i8G(L8mu8BbiUE`+xU8PIXI+ap@nj*o6YSlq#Dsr%EWxz8CfO6oiYJ>%IJL z911~}y8{_zpME7d9T4@q&qkE4_9Z&FGfP-4?)A-+mg_wXkAzU@ek|2By|vri27SX~ z_ZQIR6&+)9_WWL|_S^-wbr}_B$>$DRQgs`b_z#K7!kS+C;6xx}F$)nkAu_KM>Vo(g zWpPctR;_~1KHeGlY@@5oPbxM5FhHba1p;@j6r_lta(DKt?_B82R@`eAUd& z%&entW0t2c*5muefuQ_5s4S73^=R8ce`v&OK^MVOdVDCUDyU`jX+3V^dD)>~+}X{G zU$pN!Hy2wfwP0{K9~I}2)?duM$F2JvJ=TUpwgPIHg$_!k?;KBkZL-6ro+IQbJPl_+O8xeKgydAspMc*QQn_ zw+TmJZTOK^dm#j>MN>udk9R~E0?O0%h8 zVorMGhFRuq{*SS^Y0nS!X7lt;-y}*Kl!(|r%ElJKlxAx!7{}_0r)m+t+Bl99W-#7# zEWxrbZOVku0su{hWpr|d5V%r6_}tZHMMGVtSLUbr9j!Uk++c<(1nQ0Dwk(7d8?h_X z&&o-PHQ2seapL(^LdpZeX}~8h4*`Qud&x)Ni!g!4-*2AiTD9V^<)pm*2`Tk~?HS-i zvsx8FMT9_@{8!kG;t<*c>=sB_pD7J7IDY3V3yhdNiE;?2WU&$9225U|{1nH60;dU> zDiS=`b_Kn=6P#Nprxa=CXUcg(1gcPOJ9;LF589C{Y;mY4YB`Ihh?2nQK#bJ7$IFV{ z_t{4pe=HmdI$e6G%sP1CjCBQpihz&Y9{HPiXLIsta>J~lPrgH$5MD?=kOnF6@`srS z6+>Oe1H*b0JV*(DA)lvm^(TTYAhV^(L&dUCCK%BGhMMFY8Y+sT;As?GCMvfVkWuqRiYP6(fQ4QlR78nT(L(U39E=tzVhr+88MuQR zxRz6#sF5d=6>OE9-7_#x!Uz+}%lihx2Ke4A%vybk=~;=_?U4M6vpHI9k@mXxr!HO% zrY5R3fTP@~K2HmzY%Eo-FI9?7N(t_^I+<{^Qdt;<$q>WP2dOO^f$gXtGXrJ-9I$>2 ztU;paCYv0gQfiImu)v@xFkS@%x0zi%M!+$-!T4z~!?*q4S9M=}TK+b9?3_`_e6XHe(k-i*j z!+{Y1i_}6DD%^vSzyo0r44@IyPj*X7JE@2$mQEd4z3{XBaRDmBB+v5!pPRrOsxliV z&$|_x@~$lJ@fovrvU z{GbO@X2vPGXd1M`DfioBfB0gzG-94j$|YtO9$L*HvIXV2R7D6MJSdnI%11?tz%KZ5 zYl!~0@o?!Vv^MPoi-ihjDxx?rYZ8?qLPZrqt3HERoID$ci6ko`#0M~}&bA?vB@PTU z*kv)@Z6~RUe!N5wp(pktMXk~&5 zjp)Whgl7;z^CYNV1iGS6(BbK&bQYE^HAo#n zP_!)Mxv$Nr?}I$42TmR{CJ*TF3h$^0k;pH+j0s~W5-*^54Vavuo4KswjWCtPr(_-< zy_kw#vfyMQzjzrR9X}&R!|358T5(bj`oU-HrhyWP@kr4$_U6AOraOC^ni%`9UE|@g zm%m_xj1PG)OWwacvMDUJXxbyp8oP^j%e3}UwqMD{Y0L)N?QDsXH|d@YozEVU_JKgihUTcXsKL`;`# z89lv&F8xkriP+#YP^OHQXX4y`3K7MP;4}}MzKWvq0bQPSoQG|?{sg4UZ+sfjaU9-iX|2SF$x0Dit;9^O#RVT%+gO|Aku+4K&@s^D zCQ4Rt_;9yPJ(drJ+MGs~;Yef%82J5t`kn3b@J&({iY@hVulj>L5-O3$xWSMN@c9-n zRDjA7-r6-?B58hT0_-;NubH&YTjQ%w37pdHZQzFX>Ry{w4{UQ)psjk<@rE1so31|LZ|ykP#h#$W^_6VAZO5#83&IDwc=x!7xc$FnA$#WA~Q&OBWdBwq)r!nK1;1 z3RQuB%cOb!auMuHCu5@3>CpmIxH0BuyfoxAL|{jUvoKMj19d<-OAJJ*lzB#RZpWm@ zC_ouFT0utO#K;is^O&aJ{-A*&M)}h%@<@&%gNe!@W8xuQ)N ztaMh`qfc3T0h6698OPlF+la==t?6|#ll~=_kgyyuZ-K4L%X_Yxg`f&ld?k#|j zc-|^_cA}J!)NGs7SV3Ma5BU#rP@#l4v=AzGyu9|>2gPDh&D|o!;8;TM8M$AFJggLq zGa+Yl-m_tR>Hu<11Ly}S!g}8Khdl9;AcITgz71e7TOKR`^PJZGNCAgmSd&5+G=s?F zAS&e_RFVWOq@kbfMkfutTPl9Jl#JbHPfR5qf0&Q|b76k}`X8 z2D1r)opU|&-+SW)jpug_oDCHyVoa~X%}wENO7^g@5g{1#$1M1-wTCtD`WShqbb{!h z%n+lZshCIsstE^|`&Eh58w#k{Pr}ai62v1qVh;_JG^gzDDA}VOvHJ_Fm&oD1Gad{P zKU_7O4;-LFt5!jkoTzB!9}yh{aTv;yvwXX*Du4T@#by|hBra+q7&mVv_8p#mvOYe( ztWM)EvUb?J&(_kg1{iQ%$%{(D18DZCz^`U=k3Enas&BKDN0?wr_T!O2l!*CQJP`5D zOunA_F`JC(Th8e_2-1a{t(ge3h90X?xkR9p7%MKOArQqF{XZD-%s^%VU#^Kk*i1r? zPc;_c;^1TBt#RLPv8fxriiO(j(L71%SK(jB4%ctV07`zzlo>Sn)2R5Jvr5+7x)k;K z@G?uUH8s5shk8vEM8!^AI=zmBlUk79!4p&(#Sb%d6ktGWcCS4@SyqQU#U546os{{h z!MPpa)fWlE9>ujI+jbn^cJ0hIVGp|K=(aTD_iG=%7u|ois^R3D zjr_ly@4-Fb+@1m7C$cp4_qfsTJHBHAUL5&n1x6&KyjqWm>L1S$p*E_{E*eUaNaytP zCT4Fndn^UVcOa}W);WD`qC4EUEA_wV@)sBXR8bvN*u?B5+V z?mKjn3mYWp5=@ri%X$`uGYtk5LRcF3=gN{-|CPY=?3{@bfFwT@g9?@)%p0DZL2hhX zbQcaM;X4+coCTB430Og0;ptU>QJyWDdoTh}WDBJja$VApn`VFjKeRBC;O_ZBxxSm~ zlV;f#E`1?76sp}6!Ix8;8nz9>}u{7UpZDrtFuc&HZPTE78qkpDIP_SLcvj%=|WIm@A)Z|UOiv-phUWoC4l zmFx0ya*e1R#QJA>P%Rgk6X4`N5RE|cO3WQ-38NaWJPAV#qLM{w-Xb+;*%yC$uMT&J zG10?#O|bP9G|x?BL|4i2%h6x4 z|F^{9w?D!&2OGSUN|GjHF!`@n+n&x`=n?DMf0t0e!3w+lT0}s0hluGT6%g(H`;4Woq<~S8tUPE~?T-%_rI+|@Y#1-8>N7>k6>bCk;{F7=Yw*M5 zz~*Gm@tItQ^JtzjBu7DY=olhd+ z&&TX`?vT@x2Ix|^eWgwt#%f^1tt7PcR57ZWZ@PMP=K%fErd)i0S=n0RJO2QFFGwzp zh$K4hD;Q_>tGf1r58x!;E z&+bNt@q}HeT5UC}Fz3|rmkaF=Xzyu+qRy;Fu6-{}p&lERu}ZmAz2*^1m58zH7Ohyg zQCdT^JiF!hiinH7)%1Mutdv5u?-km)`Gk?BI9GC!D7C=(rf(*Hamev^;LAaQa(~oX zG?N{>*j8&(?XLAmPlImo$)r%VHZvrT19C(Z+`)X%b5r@!y_tZWQ?Es)U?k<60Xj4A z$@~${!CSVvY^i9ec+H8vA?VwIc+uxMt?E1Oz~`9*o5m3qwftP?Ppq8(Er=mS!`)QAkbZ9gIt?JN`)_1o$Ju6%o z@r&DGIkFGsM%Ce*IkNS*wp5+)F-uFY6qmL-{?uX}WnHj|!!@MmEhhXEJx;Qj=hVwu zxTXB3jEawJj-3_7p!_|*_+t6?;r+9=5$L5;=_%wg?a8G!%jX{t2|h8LwTfZ>@}zN# zjpeyhkK(FcHrvd~q7Jh#*ZAjTr!E!CyZ>lCC+5Y>djG?r z4Hx6&>a-r#ORZkYfFKdci!U{I@tK|sTG`??l@JNLEM}v4Vc#4PClYQG09df z06+ms=}U4U@#H-{$k=zFbam*}l#G;<%g@h!$Q6#0+chW? zr5s*b*)H#LU72?1!I4N1(=Q}v+%2fI;{aX7RW;Ymo_*BLIA|97dH+$pyqVoiPV z7}R808rm}hVECC@3@3gNovcEj$Wrem0x6-%sGKe028yS;~Yulek9j- zJfgdK#pb>@a~Lj-3EFyZ(~)?PzTkTd``XK@_0zxpiu>a+&$jool7-rLyXP^dW%kz- zSq-bc$lUJFYT3MFduzc<$ytNVSd}t;oTz8iJH9)F3^ZLM=dpQaw>KXznZIvkox@xogtG9LOkGT{l_046!KfEWFnhk4YXk;`|&%gGzPlrUQl8Nis z$$9|)%KlEFhxIc;>@spysL;26CB6~kA*v+_=EuecKp321CuOq$r4Z@*D74co8et}} zw4>ikz$LAnFTek&r=$7LMDwb8x#QrazM$@@y%i=sqXP3sG-cLq)BdB*e=YoDTL0%v zd8f;N?>+xF{cOWbg`n-%f2A-!A|tO#!3&~1e}J}OGY11WQ7$-M!rd4e6FhOcf6=Y} z@JU`kA&*5w3DRUJvy4(_BiC6)a~_jAjY&TfS6sTa_?*p> zSESWThj+Yr!xTFO;Z<9RvxtnyV}2Y`#290qYO?&?0%L@ALbya`@NYvw!aS z!D-n3`OvIBuT(wL_PK}EN;YNX(C_13wkt!ORc%{#<{e^eja`In`_a8 z(nv^7?|A;Si?b`Zm0AdYRHV8o_K1cHggF1c8NB`Q@Sl0{W@-*p_}uka{2u=Bm9hBi z3Bz+|M;UBT-*G6%FWgWZ=+dpLQE^fLXr0sWc~ZbvLk~ z*Q-G`yAtU-!R;a;%`h^DYtlv-=h|nS)yXB9AELQcccFXDJ*qEMl{g2jX1rIwPhT-L zrvB|B2PpJFj+C`D=vWO1htBF1^inyw_e~?uC7bU$>4@4Kwb`^nXYVBU9nkuvYX9*d z{yKDh9o!~_uDqsY_0WA|K-3IOo8|y03|rG;ky1uX#V4?$Zws-_2wf@qBw*hQ78GeQWZ2Gd}cV{N77O2a?&nUNxf z=hMIZ2~5@}CvSp?MQnvkP3-hYGP#00vaHK3-kzqQ19A4IShFujLp&6Iy#P5sKJbj9 zhNhnr1FPlWy-&M>@0p?h~20+}3NcPJXOAEO3X3y#;=a}X?O&$`1n0N|HtExRCG zH;}ZdZzl01yo~an5A&7||Mu?4;NUSWLRB3r|vtu9`R z^?^kUCJDaarlWt>pywu{#%a2rLL1zPnU8)?I}hw+=@BY9?5C&xK$Gty?=f*fSUos zw2ZAS7~gRa&KQI?NF7bO9n(5G+sp+AvTMPiU5q`Ca(C^92hUvJzguQwvVy!7LqCoe z&3O9F6|j?Rk8AJYIs{cZc5#r|Dj$0GB8aj~Pzk1cyUIYaM06nFh}()%(C1;lLa;b z_$Foc8hA=Kl=2D~etm_uX+0 zuAZ6g-1p7tc_{G8%@w5K;vL=ZG&gWUi}W1{@%9=H+!N?m4i1&+*;>e=<206>)J=(8 zaG6Ho>zuNs)?c*dD7aQuM*(yWeNqZXYQdH|L?J?+Ep-~vncd?$CFsJ=B4+jJY$v&P z282_x&RWQ|)~-X7#vE~3GZZN1~KzSIc~eDg<1 z02GIe&4+HYW}w?eZs;y=@79z zY#JAwel|p#G4&JY8#it|viZkcIBl$_vq6Tin#Gp0nw~kuTI_6FI7NwMkR$2STn6fX zvh@;1k%mfFLw{pEeQ2=@&Wdfp{3K#2#$V^JoA2?9DZ@bm+IUZJjY5; z`n0rOP42AKxzx&8_2ZRKFU;v}x!3P0?awv5y#yrr&AC0k=YN0g*?$gtFz&Wt+&gaL zuB)C!NN&oeV-1QN%& zP7}0QV{kf7I`^i|p>65ho7}2XBBFVu`>$iwucS_W2pa>!W;3_SXN39d*q0nIBz+{< z5b0If%WROnKT;MKbb@?0YEbfH$v}c5XQByLgPa5i3hh7sHx_sn&%BESz$h)UN6M}Q z0xI=54VONj_id%iqCeY0s2|;BnNx%_#u)wQxx4_P;opE@29alyP#MhXPq^k~e>h!9 z+QDGTmdQ%om{Jho!^4P(#ZbD_z4ubY_(v2$jdGl$|f?3A?> zdG`$YOyA7>Yv}PM!`@j#&Zw=6^-dk!qY2_&bqGRdM@LhvSsT>Y!99OcH+o;^Ktq(B zgdT6JAhbbFBi!oy`zvnhM)z!@%$GV)J;ZZXoFC>o3qrEmjf5nRkOQpBQ-OQ+(==A3 z_XNVV3F+6_8!X{@Gbv7;y4QaooEoyR^ScFEhJ6U#0+ zFK(LrQ-3J*z?-WpdS2;IINsj7a(T~};gh!-TRm{(9~@zAUPhk5F*w!OQ?uzW^SQ{r z)_uvJ)F5USTyU?bNY;qS$@~*;)rlsxYc=rEkj3MEpk}M0tS-A+#(V+_RP(>^{j0&R zlSp&+q#?^PhMEyww^sOMDke2G*4<`JN80O!DXHAznv|jhv^xq}_$e0l%viaj9RylZ z?h6j>jE5y4ospXh3~D)Pa7={65Lui>B&vvm3RGC;R$5XDvPf*o<)!yIR`|ZE1+PsX z?{S{PMTMLXI9MhFEu&NsF3}U@D8h+kYIaa`zkT+;MwqRsE>K>2Ubr8VlGjqxe<JblneV_MOB! zYwU_`h~Yw%bqtIC(kh{_5N&OKuT)*`8X|QC93)0X=nJMBkpdfW5I{PON3^=gn33mX z2|e^RT~|)7oqe?zS&0s4$E%5u8CC3&B+(V4?IvjCB#h3Y zn#=CI{*yc3sjts|X;49QOy!cCI{J;r{&lnJuDx;CJbc;X@bu;FF1t6MnMp$*{`LH^ z%V83m^OE5wj&v9{b_F|mThcc6(=Sfcc*ULoJy{R3nPQJ~{%WAs3qbS=_GVNf5c+sg zn^j^dr-3EaaTF}_P-J#P%hG!d%8tTKCsUeFrgRdPY&&RuYVuc1>b%OBQ@v9@8&Ibl zyXKBt?a@cU`hRKjTTIt=FxuJnl<-xr`BvFiM;~X)vgg1;yP6&+gF;U(FVy$F36{w! zxkS9ZlHE&|uMI=ytkzT$-rtxH98sfYy|qIMOiETyGulEBha9)!Q#1g2H6a#6#v(0i zf&C2D13VLwSvLV_hRKbYNFK0Z11xi#{&(^e6gnT8LH*c>VDYs9^Uvk`W=zGUt-g5Y z@_az$$x`A88I_jv0c7q}{BK_(k$?E^Z{#mcTR5MG3x&nlcvqgQ!qU8RiVL!gt&e^6 z9w8_2Xb^`A>(3~)tfhxYaaEzC0$ljWqF`ym-F)K95*~S+4E16q9df4vr{j`|vc(gs zx>^1H-4axA`Wx2q%h@Y8r!~2j#60Xu+^A}8ZF)Bu5)wDOE?@5a5jerW)l6C^akY4O zhiua+LfdFXJc%&MxuY37g)Vb=OiudIDJ`E^P|EcyRGIZ0FG-xxMHPw^UYHDmV^uHZ zL?Omoka%tU7}Dmo>q*ra9jSSE34f-O_U_-|jELqlGRH|ZCRFC#wvKY4Xw0Z|oIdFp zs(op`jev6cMG=N}8%REr`xe6V(*Uj-T>t4QDu|PYX^RC(;y4!$6*uK{Dy{UDQ+W7E z-@aX7!2$itld5TwsMzvid~&9pbTle0%wMRpOzm2KY6a`5ZE3cf`MR|}owV7O^ZRoc zg$E1!7#B|xGjY7l$9ycydN{-^mhkX3tMfrx2+H8|%#w_^tePMciK9)jf{-x_@HGfl z8Q^Sy&K~}bisx?NJE@3->NGcUGXJIKw@%0!Qvb)3S8~T;#-6t_;k9Bg0a!ZE%!X8IH_sTs7^QEh{S< zmwx>I;~$;}&N;V##m2Dp2QjQJ9y?+=I|mp=7{3$>1LX&@055TX?Q5@; zO-Hr$ISk8_<98g{qKDgJW4Uw;2tPF!)^RDY=%0cy^m*Q{UaSTeIf?M<*J<*%%F#0_yg_)Bz69KhJWEUpf7{pJ{ zTiU!HyOomo3ePa<{Fy&oNNCwCOixSs{Z3=mz~ng9eNdKSiD7WvMMO=BGli&RP?Z6W zR~{rn483TVAq1BE=mH=_QskvHda6={xc)gstBOm9qH`{24%74N#Srry7Domu8J6#; zo#{7#4>OKksgI+)TGaT*Uv%tH`mEt6F0JaZQ(b-ZlI+aY>n%1t9Osc|tvX9$U3IN^ zJ<7q$45&>_Hw8HZ;w=6CS3-m6cNy+Ex661Q@aR%VYZH!wG^#3IJbO%qAE~l*wgR}f zFv>3m>kG<`)!yO_j=47<@fQS0l$%_$KEO20=}A`9A!jP>vYE@94wY%zrH1Q{`<}K7f`!rg|hGgEI8s?pk_X0c#_UXBox2WUZ#4$g8<;7 z2C=yIo)S=r$$|H>b*u)=^kwt;LAr)o$8$;+zIt-y$ld#S1T}3ABg#G{R_8G>+M(vr zi4^tpfu83yuJ(&zg;Vvr0z>HE@1D5_)bO>p_iI=jgpT1>dnEQU3P#j{jDf zr+$8azEa{k_q*lsogaOHsNkAKoS)!C?r`=zQWxw0c+qn9Ljtw15YMrBa=`zRr35;B z&msg+AqQCa0k!p{UHxor(E`5f(JxOyCsojCjCmquYx7W$K7(1I#EZu8vPY?Le1vla z+O!3HKe8*LK%G<4*?39VC^L|F@?Qu+Cj4B5bHyrk^B({tfY0lfrbVfaq|2MAsrK^r z4~sR+N6pG{Q>o`U9E=iB%WkKZqluQANr+ENh@T0b0=4H$!ndQ9zYqt+=2s7|^nW9U zow&e_K`Z~@^)pSvsOVyD%M#Cb`6p9uS;LR9AYDa}FuH}(om$WbtVyb>Vm?HcVo}Ah zoFB9Ru#`7^H0KEAjxgY~`xX2=#61T0t8*XC8(3E=p{BtpK-#4MK(*&=aZghwqJ8Vqjq` zB`RG>7R?_a5cGDZLgXcTEcrnb8$$#1jgeLE3!p?g*m@9ZJ!+RR0dRZ`!AP{2V6Er# zjnQv0N{o=iSISBYn5EasNJ(y@5r$`g#~LX0FW~l)e_q>_MVA;~Wc=)4{0hnV*;V$F zDlu;RXws&v-a*b134Z+bUhTx?9|Mf6E`d(;>V$fl^tYeX9oek@aAHg{(YQbzEx<1o z%u#K0ss2>%Md3hUrw@5{ppcsRUMq1yD?oz(On}}=AqUaPqlzlMk|!q~;>RaDsSI2P z+1+^o=gH`7BtOx^s708mu^!_47qs@SW(Fi+JQ?_i1Z|;$^cgK?;OOr z9>Q8Kcw^_awBD^!Pw$y5CY&#Q01LdtZxdrM?sv)X?wW4Z^=5g3_p@{tjE2{|WzN9u-B}hY=g_NAyb! zLW+SQ%@Q3x8BZesXi3nGQq1K1GamqOsr-%}LFrpFCQ2sEvzBPmL$GbY&@na&K%!Ft zc$!k&BtSo+$my0rg(Dwq&H;80+u`N+`X!)xtd2Dh&l5EALNbwzP|%n%MuJ64P(R;b zAQEK=J5Wxy3!BG90oD7Je@6smi7y%N2{L0$;zKUuMGW;0hITJd`#Dg%ml61@><3+- zW}J(o1J(92f_*=$+4*M;tZMm4h+jK&q~w;@Y(0L^1P-N-Ls`}b>8>y)@O{ueGW7t{ zI^2MGI+r?|pjP}|E&D6=Muz%lF|-bV2qDZw0}#>k=z&$-wT{klEYxP|*R4yCPZ_vt zyFw%#TU;daO9NvmiT*;z{>;G1erNu=b#RZjKt3q%?EMPa@=#gRgle&j`4$>%mlvFC zfZrfsJX;n#N3?NdfD#rYlFF|F!0;Ua4FGB&;~h&EH=TKPMzHMMO+Q)b;;>2U#8gO4 zu-h-@){9@_j)Sj)(lO-HW!g-fyTbOvw(21F0Q(SHB_F;)q;1+ZLMWdy*MpDAE<762!(rFdn6B-lX) z{)V?^+^CigRAv7K3kGRaqANgm4 z+n)T|D=-eNz(onv=K65g&f@G)YIdj-@JjqK)NJl&oKVs-{k?^Ng6Nk-@q%<8?7Bn} zb@BW!ts^fHm})+PQ<7x0MxZVX*sg9o)98Td7aSk5>?)<;Ypat>Slab05 zwJPg?4o!y2<<-tE_Q@{66tg5e1=tNbq%Du*y0gFpJ~V*t8@hQ5K#Y=6?sCIT@Go(P z!$-6*45wDf+Il-?iGM^h=Qy@saswGz2OhWdGj!II zTPlJcjhCGp5Xe8D%!51;E@%y~(Q?YRQcYwifZZ;@b+b>5dUbWD2~T9H{l-#3yiNd7 z&G``4bfNt-3oEU5w_MPoiL6?(nr>d|{COES{S{}I=dx|&;?_KuSOgskc6mm0Q6K~D z`Ij~2kL|EaCua+`wb$R8>VHtVT-s7x`o;O}TprocWnx42Dz)FuWxOcGasO1!<@m6o zEn81N8E`W}_qrS?c}wd}5z1#b%C9Ru-)s?|!0Ed`LJ{9U=%m{qKa0;>P&5*1OuW|E zdOKHH3n{-5AoZh2iN5*XzEzN3iFPw1%5=49l#K!rl8ve$l%sM6H)^~4HhCRSJvK>) z3+eKv?^$eo-imroKozqu$sT!rCaAS8z;=NzOah1&`_~iz2`qmE3*czzi#EWw0{y>-hVnAm_R;t?5A=Qyv~z}L9vOoW z4FAJNWn+&44@LcLxp$W&v!*P{R&#T^%;AK^vAt$B!75Rk|9rM;`O3L@d1_Dt>T+yP zS;C&pnmL;XaIOMXXNl_CXU^6`C)&xKT*}EjH6yE2{?~d7fqe^J_V1~bpUqpRApR1x>%h)6LMy}=y+%UAfXeDOk68+$ANJX9 z`x32j^I~mfPo+fb-5g4=5Id+eg#<^=KaOT8%jAx9iB7Z+UMTp|wxD(Au_uIhG2f>< znSC-&G@h2LGMa@8VmE6ECbX~Y?05B8(+$oIZ2mv`5Q=-HY!f`|`K^_hj?dryknz&z#-P zS1vZZ=zTB!k9_>Uq1oB5H2fVz7^`Jq5}X180*NXFiZhb{VuPX37rdg^4|o54;APWJ zMKhHAkQv&Y$p#xg)P|VhE|M_&8-@!OZk$wf`hWsJHmhjjU^cTv#y!4?2kc@SF<~&* zTM=?$3@fu_uFt-qBu@*2g#jhU@>i4ELukPQ_fTZDL)7l;t(JYr(8!=m;=NbCrB=^` zrd+yr5qfs6FLQP&qv-nC`o2pCmM0hU)7JK-9B==x#q;fEn>22HBC|*EIUk19giwdP zhLMhWJcIhzes%TMZRqhZV@}BjkOm6 zCEVAt#4rA*4O@jJE|OZ)hxSu7F@;--r{!fahk!4BmR((oXLd&gUQn})x|isqmox1= zMyOV@L+EjqCYp&&fmeg=B3|eE*^QHO{G1Q!G^yH+zgsFcFbu%}!P-P38>lLer-Bu6 zB6%{kScKBa`D)fJ!yS2VNc@yin3-# zv;6n1L|;LOHP_-#H3{1s41J_UoJu7vjJ?FIY|nT_q8xw zFoQP_Gq{7xm@ASxgJbQBons#)ZVnjoI*_}vpG(NYE#~X6MG_A2+3!A|46Q>R59MXO z{tLUOH?m#~KeEq%5k}kQfePJ6@U zxE!w}Bzo!Yo(_cTo_{v6Q@2sS25v-s&Av+jO$xkpkIe^e)yteVFp%k`WFtI~o1#|) z1)GdB?!M(la1e%DZml%gzovdES)VK(4f-mr|5m)^zIMiLgVW}pJqbsQV@~9zL@e<2 zPK#dZmX;c2-z6+_jRM+A4G)bPG`%;DvQOe;Ybisj9@YSGMw7TfiO*IbiUCRllAJ%y zR_F6k%Y-TqpRA6A79af8RZoXRHI$bi-*@+__a6jR9C{CF*zptB>TIHdxm6fMh?{AN zF+qhL1?xPR6r4Bz+Quby23dG3@DH1q@3pZyJsxDZ^ROB!L-PVuGf52AA#&uLo6w$B zg)@on2#6e%fhGVU5g3@Ghyg_yNRV#B`Ra1cdQzghOlO4--R2yRZCinu^&Tn=Pi?{P zwl!X8a4+;kgSJ;Pea!63^|)#o)@HvHEky#&VQG`?OndEKbgb_6rFff@_aaUIyjHtn zzF&>sJo4-T=}TTb3v+e98<-IZQ~*XBXTI67`cuINIs_xeXSIJ+B6hIUah5A z%1cXIiR;Zr`QQ(9DFN{{de&|f0Bc9Gq`9IB}7xWn~O8pINP zIv$@X;F2*-DZAcB5eCn2pV>R8FvT9Z6_eYwZ5K(W`3fzAjN2Jiu{sM8DB!sXkUJ&E z%-U@jv^d@B^7L^xk^lhUc^p(TS?2|&N>HVF$1lw?boc6cyO~1wKPU{|-tDw7qD5lx zO-lPqF8?$b?ApDpGwk}$d>?!TW%2_uitw&FX=s9CyuGs^$%GN*loc874Y?ebN>(6( zgH@UtPzKXO)9o|N3F3}EK=aVPW6Fi2nwlUn97-s%CwHt3cb<}WxGt#V`;7wCH$Y1NpxHDcpYX_w2skF@IlVX53o$uXkvivQ5=_O?+-{Vj1_tM}J~+ z)Aq~gyIp}}4Ov0A{1nDoE@_^6eJ76H}}6gjSBUk4}aH4RH)7lG)QTRFPPXW{Hd&|$?2 z^M=QyV9n89=qbI%f+i`%K}c@fF*2kZ{p4z8#vm?W^~o)4&!y$O5)2ox&yYV>zr*e8 z#6LCg#8bnx3IU1|8$jWj?s=z1&4^8>2Hcf{`f(P-t4;cc0noM~o zfB-d@R~9WsEYPmLVHNLjXBz}cnzg5WBi#M_c7jn-*XDxWRf9{J4j&%F`n8*Oe!o(; zf|}{S{-WtzMR(*crjm^o+w-8mf1uDLE2twtU+c`tpkP{zzX&#RK{~JyI(^A zH2ul8OiOk)Z&z`2nC;Rk#CdGschTLjPYtWlMq?Q&Sj!MWA;joFGY|Y9Rf|zrKKHL@ z)Pmia6bj*$?#Bm~Ss1-V8r(L8{NVVoMSxM^%FEqklplbjCQ||vdI8y5myHS{<`dL( z78DLfc;J+7dbx~}-D76jwHMcN%_wAK3Jol1+qYt#Jv_i3MLGYVX>aH6z|cx84M_fv zeScl@1FxT#Pp!p?Iv|*orJPV5=E^|am?0f-H#@XQ8oC330)Q4$0}DokY*XBBooUQt zrqxzk)Z;TsL_1QcyML)2>hvIHJSM8Sn$i*)wXBYQEaH5)6GhxO!2X(w4n$pTidq1e z+ih@yt2moRb5zU2RH;@|JcrpsxsOR|Z6sn619ee~wEd!_oXbeQ&^j5IT`16aPEy?dwxi8R3|H(exO4RN3L$Ye8nYCQ-e)Som;tUIYCK1)b zLMLN@T4I1M(an%Ru93j42y%LAfRLdiv_;*&kkVY9RGJ|}{6pD(XBp;O35(&|%&-6@ z1?aIFwDAyXYyiDVvKuW&*M=)n4j=6jWjlj4^kOx-==x8jMhfzIF;jPYjK%^L)kQEa z79rY2MqBNIiln7ZP4FBs6JAl4$qNc4!9!@!0?)vdHG*hMAJeLtJ&teMLi~F~Ia706H@q z<+3lACszjnCHAVJ$uOFEC*y06n9`e z246eE8UM`b9pq$1p0zVK_wL2E3Q(JbHF9qEBn>4e`juGJntKuFSg4X8$T|z+C239c zg8pL+;gOVznW82O_x#L6&H&VQGwM{g*mf!CMG1GACs5_XKyd%*j-^=Oor2K`P5q&Q z$AM1&Mrag8D9)B^%Ce6B`FiNMfv?c`Xu&Azs&`&Op@s_bJ2LCI z>q^h~s4wvJM|>3C&TW(Y`N{;nzYH*w!o-h3Y4n%^x`NG}ka60P=dY^z&m^_M0-)j1 zR|W7G8hQ+9_ne289mqjXvl{Ze?A|ivCtO|2QsZfa#zNvzO`nnsLGgVq`VF(mys&gS z8(rXqnnq-Hkc`iHX>>BxK$EWfTHwkR@FN%l(&J#rQ8=8lAB#1(S`f_ahg~>S4t`U1 zur&b<3JfGcy)aO3Ictpv*+tl6|K{Y2e<8c*P%i>B6aamZ0DbWt;!UpzAV8CUmWA@6 zyLl%kz2sav=#-h909lzPYCAyG4&zjDI#SNiU|f-^@741^Ujgte;GxSvJiQugp9)<^AReZ%z3juli%EhqP0yPdEd)q{Nt zlgan}D49Pvk3K8CB%-<;+_Mn1j|jB|jB^RXdhiABdW2%2pw=t_`8JfAx1{FpOzb5P z!L7CV(`vK=HS@Ooj-#5d#?)fSigiipgd>_7P8ohHK0lis$&-i9n%Qa2@t+`c9}4yn z{?&M3U^kSm`E2bHm%T4NFk9iRt!IJn3kCHkvf9Z61Zl+e0Wju5JZobvkvRP@8pj2VpwFOZ$@49WZ3|@OuTOlx% zc#=kgYCRoDAwk1L1EIXu7vZgc%%MeZTSMv4xSx4z>DnE4 z(_D8c69L3=A}S(q7_)~ z&ms?QC45N?WIfxjtM3d-douD^TvTou4f;s8qP@%7-92yPir7K z)XJoE_{G}CYr3x-Zs$x>*XsA0wPn0*A=1+I&fwH$JG#$VWaL5AXISc&NO`McL-|V+ zlmf9!aJTt3#VaHvPAOxn*HHa>x@+Yj=r5#WLVD-vz59TL7RAWFY>CgLwE-K*YF7Qb z&RdhNdoQj{EFP@`zq-f`tgYwT^4O@?Z`Jn7HgwJ5YI`G4&zb3G+fk`Pp8=P~+FZ?= zqm#c{K!rr~vw4^~1_T2DAhc9vemEQe_=^Tht%Ak#5hnd^orh-&Fo=yep4%)!Po9O1 zeudqyhNTf;IZUK(sKObR9CoXqzx?zN4StpdKg)kwO{!zXDdcCmaP5&(L}F?LwKUc# zC06~xf?DeJ`l1QjkaXSS2#x26qtS}S1DVgiM`?745&x_hjXsSFgvGmGzV{+fuOgP9 zbQy8zF#LCRZ)37Saa0899untGwWC(MLlI@4w6)|LF4qW z*ImlLF^emHXamCo)Ays)vUj(1n2$Y6OMFb3ypr<|D=hFH*4n^uj)qQnQA6&kuQ_d- zJ3R5aW!EdV)bfdw`F9P&yXpZoTmBuZ+e=0U`V-A^Y)w#k;7TMt8;L!0izgA!iBNmX zb@n>jy%(Xru~52Zc5dN1pn6BGyIcxfl-iRK_~Q247TqtMd$2E%9~BipZdG{epm0u% zyTo)ZvAs@L01R}+J8w^mL%GO${`^BNjGl8zyb4sgP1Vx&Dj%Tm^Cz*Ap+ew@Q$$tOO?V3 zd58U&(?1;IL-Si8i)4X)SM-5)NMMjcb9f-W&sbyRLMA4}t&mVm9z}YVV$1r9 z!awzik%r0Hy5OU*{P;A9+PS;^cebg&ZPBa4xC9uCE%zWwrar^$65#IeJz}Qyt}hLA zf_2#!CBrYrq)_jkze%_MhE0~bl)dQ4DLZ-n@6!OIwqv_HnmQCq5%DxrmF=jCAP|0r zq%9)?(U9?f-jU3Sn+Iv=n}#`or^gNc8vnOp{IaPE(9D5Z+c6=%;R1{d+F}*xjdmT^ zxq7lOyky$=KczbV+d}Vrl@rp?i}{GO?t1RtMB7ObZ!hOkUD(UIm7Qv{Yr1dlPwp#s z(pMcCA1$abvol^THn>ncLsVY*zZDMkg^ykpp;qb2&uHk~pO*KP))m@nCi&;h_3h0o zRe!szkk_bRvVQ9$o7l4UreMo$|G*W$9{pO6+pqfcO^zV{*eURB-qg$9GkLphUka}xDF|P=#)m1S1Ky^Hpp6fWbJhsyq>mf2lLrbD)A>+<0o13#L6)*K zz?RD2SD`3PE>p#{MB9mNZR(SYD7Q_VYToJbx;U`O9j6BMnQkHz3QvGk442yc@s3Gj zRXvfW4;ow{kB5Pl3FA6^O3UyM1zb{tdCF@w- zX0C}uF$)X^!;zZiq(9rCYL<>VhGq@v5T!OTv#I@#_Ccghn@+|Co2P$j!=SApxm9jQ zXIS488GX9;Z2US?Z>c~pt-~d+Dv%~=tqY{K->FO3YmkQvFtu2&3g6q_ikl9Y=&92@ ziVN~g=}HO|83le)tPXEfDfSMcMYh}t?wWvA8(m~%h=PSzEAAViuH@r0r3?A>Dg949 z>(j;?5*o7?ZvZgeaquGq+rQNA3mjt0ufPp;yYU z;M*v9Rl&|Zy{bX3wauyluAkN?cWu$#IgN8MX~9N#o-V5Jw!A89u83rJ89(tc@2ST6 ztnzc)oXstgi)C96b z-;tE(H(Kqg0|`^sY5`Q=k;<#f2kg~6T$aO)^Z*WhT2MFb5-# z{&xI_H5EUwSQwnvGt;^dee<0t*z|fSzq}QriTj4Rp2&AG;Gi_k>7bvc+=HOGHz@>& zhO4=;7^pa%F9E`%i8=)r(l_dvx3mg$POb|?4e8s*m~}p%S31#AV-^Xxv^e{CPlv0( z^*W=Tp4Pf>nxN(6&Atp|4dqd*^CWu-;|_4oHkcA0^LPH zekQ3QmhLo851VERWBs3#?z(NtE&+6NP9UNy&Io}mGtSTtH6aRKnGm~Y4l?5*u*guqiZ}gdU?5a$&$$U=%FjqfIN#h14K;VP+*nG8GUY^ zRn`NL?%NuOIl@k12WC{Y6#%_RA}j97<`CL=`PM?Y$tw|f>l+QGAtnLk6>>ztJV!N^ z4o0z9ARV3$6d*?=5+n*yG@{c002-e-^nVBV!5TA6s@Pk5e?~rWn6ZBnIjH&?kZb;X z!eX)sgsb7*bqR?_hj2Y~BW5a{r3UB_5lB@OiE&v;Mg24Ap_d86S$6fnohw0VGX!Xt zpTK%Ovyjw5%(Ik*FbE-B5U!DyXKrku5Gpnu=;cFQ?WcFshCDpJ&64HY)pPbxcwr$n8zUI6A_^Oe#KZDn~=jb_0G$ z9$9JeI5<85v$S^S<6z*8tLnnDu)!2|1-qE1!^vEj-fev@gBVM@(r zqL3A$_~dSLB;DvQFQ4GrUIXzqHPd|m;vl+dj>e!HjVN!$xK5^{EcZzD5|2&Qzy4&Z zgp?)k(y_+b2+2jag@!&MSS2TtRqH=92Yt_e)jwFDZDHf0pmtH{sd|3FNMGQtJ)1c4 zJb75*!(n&Tmky83-w_J~^L$R9U(L1Z*s5S584IgU42_Onw%TVwR*WXa5>D}Rt<8bu zS-G4mtvhXOs^sn=eysXs0N`t;h5iko0Fk79lpbUYT1KdPKHusQz}{A2i*n#VWP%x`R| z-K}sxtVs$c*aOQB?%!xyCEyHuZ&H;cS|9aHZq1U7)aE4%+WHI$a5!3?k6?cvr_lU$|B!(h}{0{ zO9!B40Z1w9=mq7zLS-0@B?hVF_#f35bFOBGlj$S5*tXRi|KnBLB@N$22+e zR0_iwDxodm2V~~(OeQ9a2LFo&KgN}(K;S-01;@-){HBNfnhU%`;28sz8Hf>*3*|s| ze)?9IO=SLUp7YRb{@{X3=!#22(M}~~!%=h9m!EA!0_^9bd75_mFI}%`rB~c*!j5k! zu3kRhRG)X7em;`ljdwxKFmbD{O+CByJFeFUa1GZD)qW8<+K0*m^9^{yLG_EXSB;Y@ykAnRnBa>TDl8)v0PH8#tT1=@qq(rtX1*C~# z$<(Pi>hi3+mEE7jPE?{D{uQe~8K?#T>SufCb2+A@UQ+-l(g&zM4?xZWP*gbr@wSC# zukI*N_61nD4%A-(ij+VE1W1px!)%TH_4=soTj0eLw{E1p&}@v@(mqQ(K8G|jzoI9#zakX}#u~Go6>az&>jB~O zbT{AC3386e`HKT~Ws#@4Kvwe(9G)2?$6vLJy!Th1#~<%X~bD-@acP^)w=2S9Zy+fpIq?OCehqJu3qo@$LMbZ33bL z)Z{Fe?3`2@bm&ukg8g}U0PVRY4vohpOJP|&k3$IS3?U|q2u~x*>0>m{Tng0~PW@f5 zYk1iEIVEcYj%Y8~RgCdJK||Q6ctds`7_E?Cis8v!Dg3;1i~PDl!U2qDvt^ZOU9r(8 z{b}(4W)HS~3n7J6()`K>)1nlz!XKet6Ps6FeWSFC-)-=^IDlzk6sN(7r?o(&c3k}thdJY3kDj?m>^Set?$;}^L)^s`$&&NH?t;! z^+`arHQ;u74o+TTrrt+kfGxP~ZTGgA(#|$dUooNqTjM~+gb-5}8<)t@lLGbE{s^GSXvzjkJlU$7HefI9O#h>D=JogjBSIkpz8U}e1j4Vyp*wHo^p zX768~EktkWReXN`qO`X-4u1?OaV7|jG=+afJ!~~cD8I9J7J*beKvsNm3F_fgsyV$A ztWELPOcidbaR{3N>z}~C!Su`AHXi<)UFya~&r^z*e&>!el~#I-BY4{S3y`o4NT@U~ zL<$Xao(qY9+^6M{nJlzCt|0}?v6X*2cZS!oMDB8#HYj@_KU18ans`1^2u-_CoFay0 z(i?`fsw5WL&34xE@xV^w0>olI*oNjkEcI@tKu=12r-ygBNpF=~l(j^x@!+v5)YSCm=DOS*t+mYN1^uJ8xYZjK<00VI*s+oz_HU_)r24U+9 z7w>Tm48NY&THF+}hBHg5s;t~p`c~gY(0$LnBp?x*HVb@mVN}7&_6Zo|4!!Sq z;jw!>NLLEf8x{Hc+Uo*B{Nk7fqqbsI$d6O&#~jP)rmeALRji_5Sz@A%>yv= zVK?60KIK1VP6ZRFIeG*4+`mpkM08ym5G|hGp9WCS3rjv3)_<9yLkjCJ3?s0@0#Hww z@7W1izB;qO1e2KT>M+9(V+k+DzITQ>r7VsZT@e!1ZB`VZuRtIxQB{u{_g%Rtr9_0SB~HP*vQ@fg;2Ay7ZLB&p6bY~ zAHH}8Iy&c|8P})zPNI_<`G(J#qC~2$_q}1Fx6OB#F6J)GqC6FXJWp$%FVX%tqBv3$ z@;$q(!vS!_W9V1S-wN3t8i|}MFY>OW?GBsG+Y_O4r9JOkdtM|h@0&}1h_pY2j*6lc zk1&fbc^d@9N#zm4p}gmb-*zL|!w3d#)P@E%DLBfb9Aol)jSG%(hZUhCX%Unx=i!YX zic!=&oMkZ%do5H7Q;JHsq_B6f3A!uIseAEyO-J^5*TrE`X=e;C^R7J0=Ny8RX) zo3tuSYO&MeDqXU^yh(3(EYIjNEXV??S9^VZ72+*a>AwKkCJpkTRtE_ovZ`x$BkzIv z!W=~%JoBob!j((TEDqxS=tV>i^9|4=)~}n+xmXqIuohxAJGR}yR{3qOcPz*6IOo)V zQ|o+%?OI7qsgkO)6I8kb64A_ zSo86&S1&)bQ9Oi`ad@Vo;ZuN}yX>Fee6iaZ z2HhAwC4B~SkkU@@hD(Ms2+`hdQu`l-L6T*npY{}}IEV;U|EZ|jx9bZVcW%<+gT(a@g*J&A&?=6iY^y_~K-|PvkF!8nO2x_Lwf>bH zr#<#}rh849oMZkVv$g4>cRTK;`YEThK6`zeFt52|cDkciL46G<|Cwka@SI@!zpQPd zcRmC@5bt~?6UuQS&SN8hX`w*9L@TuoU>IZ2cPcw_iZ$FE)V+(jHCKn+GPJXXEA7 zfUj(P&9`I}sNP%stM3_9%7bgs+xlySHe6FDJ)zM`+P2`3ZR4?~B9IyJX5miIj&Bh= zkM}KmnkzrZE^7fDtRuhw3ty%C5O0l;Rln%dc;NqRsc-6C+lz)iiZpj80qQLpv^odz z)~x>8c+h`?vlH<1{J)!A69KrIv5Svc?dMpC(T{T$yD&61nx7X{ns?=4UbOS>(9vUZ zb=~gqyy))ba~1BPqw#6JQKRMUV#%^-kItU32C+B1&SR(N^Zc~e&-`*JD*x=q{Qdf7 zH4Xgu8P$z;O!GH&iHWrM7qX3(;8|&JKtb)P4Ec>eV^uYR;a~Db*+>#zfN*07ei#M5 zoU9Ki6oh7DHJg}xUq3?0%ip9QtggIrl0D;DUFtFK9o z$au8(h9Bt2jvnKI)H^546kqKyT^TD%^_5wH2`f2W0c;3>i;amg4rJ5E6mS6{E2HH( zUT7UX9rp(dId_!wT-^YXuZFzw>f2ASXcBO)qMCy%m-dP0rCgwXB%9x3bhor#IekOi zBk%X?xTizD)-L80;4BUxACiskDk*!k{KYSi$$!Mx+fVUV=}A&HBXRs{u@G*hUh<*l z>{Tfoqp$w>1JiCqVXxV)4|;DT&2WrGh>$peE$}&O=DW5%$$Htl(#)%olr%W}Sw+|D z#0uT+hj*1_sE{h_#x9SRTZYxpv~1>ls?3_6^aebyL#DZgT?_Fpjyqlbvg!Dgz~-3m z$4!z>@tdZ$Y(h}8Px|6?yQ-Q0`L08jk59X{?W*fI6lHIDLKm#Bgp=|m;^Sb|*)38c zkORkAMoZ`Ntio(MmMV%Y`x*}U+w{$6pf`2J9SY#|y&FVYwX9SWZR(nKH?UF-5H2{p zp$+)DzKQl&bTOR?igDFGll{3LOfsq2!|Ls=eX`|_%Ehis#)?2m9rd7FEO zM_HLqg;89kk5BDV*>6lcy4*YH-Xlkc0~f{}=a{i5hv_cPm}RG^V$AXhGkCw{G(Y2E zZ9;1X(m{5xDW=5XQCG~v#wTko_pD}=Vvv>>YGUNg^rj4?O)u~N5!f3XWb4`1pu)_G zDAo4N7HN#GIy zmrvm9+p3-oDU;S%9?A*lBrt8)Y^dB7Mg&q+K|&q};qA604YC-Ho1pIg%JM*ynt@(8 zzg~%#NyvOP1!={d2g1<<&4XxLgX$c({)u^B$o6+UPROpu_TX`!1wIZ#LX3pSXcmZg z`#V&#A)oO>5+-j7wJ@jh{AN03PfF z>1NUfA-dBE`d}qZ>4`78EC2`AKF?9K0l&sBZ2DIkGrtLJI{Qz4yRZpa&?T%u=Cxa2 zMp#{2`MllApw;7HzCkk8plF33dy3m#UZ^?ZeJT?pCbhQtVe`GV z3U1i~bS37D?kWcEaEF1-8W`3r=(9=vDvny?49-*#^OF6^$l!<^Rlf8R^L>0-R1Fwc z!5dBgZyp}LGN{&1E3hG1DMYPsG;WCUEQSrx6F%+{JYJqYr3mh(`2=0TgE;gGPz(`S z`2mlEk7j|0wIrx!E9Z2!7`%P9#W#_lq5O_+yiGO_#7twj_Pqop8`_Xc1|MwJOMs$; z`@tMXmF5A6Wms+DgNjLZqwcb0MmcqtVG~%5nhU#G{4~teu>AacR`Ql2t?~Uda8peM zJb(&9H7zS z4U$dX{o<#`L{4U0+-9w69c9x>h^ZSn<7A~@?EJUf+8NZ|Ao(m777TqF!ch3T9i+Q| zN@ohSAC*4f@h-FF^vj3i=vVgw4xiB2>Yq)U%@- zzleQ>1l+7zAC+PWbD%I#Vf2smDAfW7mM4mw2#R8fAvFL`WRjE|u{~y=@yas&l=q-& z=}`#5cvj&|4u`Ss?sS~YK*D3&Z@zBHHP*3GOcxK~-x91hOUc+Q>V^J(G1!J`py(qF zQ}1Wy-5z0}ch`8x^dCqeo9rb@t5G1J={yA2YYaSCIfwT1RykLD8SY36!y$pW?wT<$ z6B1b(B?9U50Z;;o9d7=b*8m_Rl0PBQB2IPvQu9MMizMYn^{R3N)M{{8Ve84Y@ix>* zg2yLvjc!xUl^PplB70?@&ozfDw!o&uFFmFL05; z_jDKvRC7>ZH9gM+kYK-ueojX(ee1q75O(g{VZSmQc(`rEicSvK7$=*j(^-(%Qoe^ROL*NFcxo-`MM+IOk7 zL{M0U!%euXN2E30U59aC*&f~v^{e7Mw2&$Jg^hR4|0p`|xTLrLk27V-5W#^95EWOVp`lp;F5C-8 zriCL@vo6letjlm`h9fgGLsK)eLQ}K(1$S9pH8bm07gyQ1T`TXs*6n(K{Qlv=zkE24 z1Lt*K@8|ROw&}_)&y)E9M}7HkVU?A;F$YpSOgf%=H+ZW`-^wd7QiDR%RW^*}sh^YT z^UCC91Ka`(CUk7bNk*}Hsxn{C$LQc^aQRQ$rA)SHtOn)ib=Gg&{% zyr+~kd#qGQ`EKswdrar-p3PfsdhhRr0fGi~<~EM!ld|u9=HZsQ+vJdfLjbKn+^kOf zTiEu}X4rnXbfH>NOe4z?WcocgJWQJ3(1iF{n8UuAe9hD@9$gX*cKyKyTWI2TdNc>4 z#}#V`642Dyvfz|ro#U(fVI0Qk;54?hzV$10N5+4c{(kV;HTm9uuUQ4B6g7Pt)`kOv z7$=azQO|l)G@w4l_z8IW(L4Lf+w9`NZDc~PeoB3iX(Ps@f~!BtF75`2IkN)yvBM1Kf}4KBlen5vg430a@om ziH-IkXs*PhJpwp<_AC3OmcdD6y2!K;E1Hbh1LbXM9GbA_P3%2s&)5520QzfX(1t&o zav)d|J0!T|dpv+w7%MGSgDRNtI77Hw9@f6M@$!J5T`kB?QQle0we20WtK~M#5GDM` zvMz4i8D3Kd*iAiKUvM8(1$MJT%$WFY^9G?_jzbb^)WD)gwVz`0gvC=F6aW65T+ zbTe;g3p>#@c9B_OuHhYOsDey3*jz&r3;;VBfX#as>EuB&8Dw5@?p2iu5;ORMcq*?u zj@WF1e7^zIFhXtYo@FirtyWYTeGI|3%5+YZ6Us4m7Z-T##{bYYxT;Eg7^NG+(VBgv z^Fy$2a%RC(Wt44-{GvrH7=4W!ePPU52=FFE0|R@W>CUwhqs7n{FH$;Gt$_{16_s`d z*p(o0Tpwb-1paYFMA96UW<*OVqc8&7fGNV#lt^RX+z&iF{ZVqESAdE5wye!a4%127C?_Lo8%9D_n zdmR6L7{wrjj(hpGg<%aE{4N#kUSlu9@tuJxd=e)yr54$x>Tw^{4GAN*j=bK`4$|@y z6COTWG1f9k2mac8+Hgum<8$@i`B{ZXega`FgdhviE-M~4%6Cxvu|Qv-s@RFrQ>%#~ zDxyGzpApk0)p(&;U&x{5Sm~!6Cr!zZ|Es)SXj!sXe2Xqd(>UlEWe|mPkmDo;7>=Z)}`x@;#2JL1*0_gbPp5+~LAo^|(msYS_x`%a7?06f`WkzD3 zB<+X=3#KHy<4@WFX6?HcoC|)js^!{09JHyG*bWcc4M;Xl#iUw49JKXB)((zaTm{+B z5$(Gszt`3|$4VJ4IM?)7n^UAr-q2=uDKh{}B}>RNsZ^T8r6KX07MDzuoZqmf{79AS zby6W93XV%944$WQscLJv9=A!4doFE}$&*^Y)X;X`J#T(Yl~?{#VB?&v`e#LT5kz}l zcFu$!YT7vU;)MDiI|-*nV5;F4r*McBm;>3Jd)8 zBS-vs!h7*j@#3A+;Y6Q|4*i_6XMLTI$pxp6?Btkhh;BR&*H!R#wTCN?3!n||Pw`5uI8Agr4)m4XHe;E1XJZ^0t_hKsErBO>zf)xTxV0^U^j_K51>Kc2hj!#zb;!6ETXiBP zpa$IO)n-5J1K{lBiQ{?Id7V|p!x~|kiXfh&Yhr7b_)$kqEH5yiLrI_+i8Aoi=X0b- zK#LimFbp&YdP^(#{aGdVSr|w!?{_^OSHIriP+h1VpQASi{3DxVl7hF$-j|&QJe08S z5XS#m&@?dxpMPMY|Biv4kdwj{lVT%QlS)eRVY%Z-d!e6Q-iqDiV2K}S(@PZ~8=x76 z)p`iq*$q3&001|b?0Y#%tq!339jldvGI%+yWZ3S2trH2c>MFN-=4M?o!K%xb%!{k%RlEh9RwM##_`@lPUPd_`jwIDv& zR@u@a^w8L{I(qN_Ju~<>tF)}u_k{b9wL926sna^Z<7`5wF+`FB1+^hHE9=6AO=(RS zNz0p5P;lBd8%CR78kL@Qj>?tWC;-pQp*uRaQ9(=rL-SR+I}L7TdHfgVVG}S!WpGg+ zUj5T5F|}*GR5v!#%(d= z+vxi}O$c~>&HJ?=d)kCWG%nRY0NMRvKou`9Mk)b0juKZ~*$%iARpG(9A>Szf*3oe*wUhCcR{SD=gb$#+0cLqgC{&?5%~SH+18>x%ol{h zwqvJe98MH^g#=8-qMr!@aBFNm2y#z!t@D?CEbPyVDv$bUvdRAf)yGrJD zl9%4?@y~@uDdHDXDZI3EBncTJp$@3dTXOgPm-^qv?)RH#tlddN%`atVejw^?9R67~ z{s)(&eewQ=f@_{KzT4}8O&-@q2(q`6z~88-+8}tc{%g{E>qHdhBPKvK^7u>Hp87qtGdtB@zRs$hWR`Blu;;Rq zZU9H}pOZidjTL--_*(maaL!~jWTWQz#O?K|Spa9_L?O=LX0R|YXfEKtzC8h}6SJ$z zFRrTa-ZTGP@-m9<;hKq_2*f-IA}B$2>?7H1cytQv zW3TN@Z_ryNeRxnb$fZZ8+A0S>>vhob1~c&2^in~s!Q5jBYdIY`)o&zL1rAj$IsZ&Y zJi0V!AN$L7k3Kpea_7EaJHFoEyLmgY)Nz&c$HvOTc&y7XH}*8QlklLf`~kg|YopPb zs!XuJsdp8lCLKT%+(v&6n!E3^%eJOhlZk6PAZM9abXfKf_1>#q;DtIpNw<)nntHdN-MO#uSCOP^q{fz0Hmr51# zvUJC~bh~;8+zXtYA3D4rKwTX&D)HFD5Syx{te5Z|+}v7Wo=Ll;1o1^C~4r7l2>@f7b#c>`PX z+A~0cZveU+{u^0)?mxde@|=Kau`m!ddp(Oy_vrA7?O?~Q7aeOx?`dhr66s$R<*3fLx1Nv_Z%eDhTjo%#l0=yCL;3E z&a3QGHmuYcT2S>z0*&rlLpy&?l=flE064teye^}e#$?>F8F+!Yutv{ z#FN>WzMgYKMYd6PvAqdX31>eXu#f3apK9nk9cdfY9rL{R&-cWxD9+b__X#CIjc&QC zE6q>5f{ktV&#J{wiXQ8_<*Y1C)$EfsaV`d);kM@EsGem3u32|qpz_cKoy$4;8&e*AN$+1%%v@s{_?QK5;a z-lyNX-j(~i+v&%5J0mVf$6bi9HMy1D_4rFtS6151Hx04w%_9vleyLC*KkT#TsW85n zfem7Hn``K%_E*c+e2MUaTelS&HmdaSyGSl!VH};W; z7}Y-GcMo=I9N!PD$Ba|?JP*>oPdqRXHk~pLAvTPMP3Xmh0O7=2d2y#7rt<%dIL!gJrCR4hfRh%9fNlx2G(GnB}j`!_D$i{)7jh_AVeS^EH)!%hpnAn05N^ zlu6#tClbm{MM4R}tN;KiHRaDWmyjC|)IHsQWRCN~M6?DmcHI3CH7$f69=@YmtPjbtTb+{URTehE z4du(Hq7UBRe{EOKnYL3qzxesKh1@}ooQnJiI9uyR0|sKZTHZWUj@6$}(n0G_cpP#| z-XpEt`boBk;Xr(8vIQO2E4`_`&2J3Dwc?j>iyOq}_OkiHtp~U4$eU==39oQ0)`{;q z`P#t0F@D%E?P2&n(?15jZo2(@|K+zPnXj=wo!$TCc6(jcw50N)Tq;_TmAJpWe*Ajl zqOn$Ik@_!0iY`b`4o5Gq=5giM;IZDvJFN&g2ktOz*JROaY8tdtg%WGPEJof|^BkI6iOuQVrASU_#yVPX7wNzW`}%>UqnNOlT{ zXp>`S&7Q9bV1aIIH@xI;VvA%m>bk3Cc2?;VvA;BMyJUS*a^CR>l2I~t_35<|3M%TR zpJX7{+TM&^p|#R5JJ4@&CGxXox?ZM#C&BpmuTLI_1pYZe!lSr@yje`hIIlZk$Btd3tLf91mv`2#h;zP`lN*>H*vR5Xf z68x>p*|INC-;jh>Cyx z3_TOElm%{Q%*M|Z7NO?8k5YrGt35FX>OFd=88GK|M|6s2 z=b18u9kl!WoTK?6ywe6Y*6W6Qw^E5+<5=gFEKa2MRlU)tWKn~y>c!J80{i!-ZWLy_N_u?74DqjljmbYLjjs9I{gUSDjHKSRwVo8zVpU@1j9 z@C`goqLV*4L~2lj^cw|Gdr^YHw8lTpx5h1{aLIDcN!_G(<(kv^KE6moyu*Rgl-eeh ztHoP2bx*rUjketyX6dOqbXXl=yfWGBMvq7Nv$^s=LG9|`o^4Rx4xIxc4$B91 z>IG1n^z!J8XepteTtqYABg2KQ*pw;8#Lgt>)z_3i zZCF4?4}h$Q|1~^I;vmW$mI;C!s0C}jFKFg>%lZLQ`1V^xZZC82Ni`22Y|AcpXC$Le z-1N|o+zqqk$44C<`9@L;VeYg+!zaSm34t89C-Ku$KaKGEO%0cR&Dtg9ht&KU7eFy6{!qI=^nb&mk`Si zgUUd*7-jQLM5i}bL}(7zhYl9GP7fgEX~D&}IlRPWe6!!6t5(fm1yo>s=Q zaQ9JZOF0*7W8G=p#X+TodfxLf-~&juSsZHt(+*ga@5ldG(Bv-~B~qV5>`7a=3I z6*@DxH^z+-hbh)F)!O~B4VJ=&k=wPZ1t-ZCNiT;3*=>$&FVrUh;*<3=ce$3kK&yWe z>@J6=k--clnEX4qKSV2230BgeF#>Q;3OEP=Eicv}#sDq=kf}&B%!N611)MC9C#jLO z%JX^J$eC1Bzx>xaK+-KXa#fDUQL*F2$uU05>B-5sgjU+T4n1Di&2hh5ylw(_Nx;!# z@6&UiC&tVZg+QYpTq66eUNVr7MA1v@Cc4e*CaBW*k~Ho-t<#clW0bgCrAyzWAwzZv z;`JgymVQ9%rT4^*x)uUkrE3|FzoR5si3$Bm%*R0diU5zlfT<&Eo~!XsT1=4yWtw>Z z7!3Q7Jn)h2^iee6eIeOP^;f{jO9JR=-hhi(mrq4q3V}QjmGCmFMLb}I7Yn` z-{@dc1;(wn`RIFGAj<@x{)j7&F|#MCIO-g zFxD?XS7<*S090!$PNsbVTnR=+Ud6u`;Ct18WyIr7KtQY*f2o8`>8s$MM{Py!hefU)xV8si01KqWwR{@F;~hJ_;JcEJfk#5S!h z-{?$_H^@4`c%~MX|NbRSg&L42xpybArgTqB2+5qQf_Qv>C|;m4*eBMrJh$mi+ArJ)dx^#@iy)@m(g#BDGK6&2nCV>#0V6A;2t-4n8N~F&ddsGf?D!mw>(Qc04 zR5I$0s0}Vb)rpetv2_)px_8LNpXCHblCBmBd;N$zaWyahkc zga?qN7gmf>tO_48Sd3L`=)n=cOFURem9qh-pk3A1EAE#A3zTnN62Lmc7b8w+T0%fK ze6^9>qc`R!M=r42r$Kf4KKaZKwsz`=mq%X=F|m52qwtkeJ{N46$k^DdrPz%m-f6aat7+nkb1r zV!{e}uSIC$XF0|u-XkFrtDZ{k6lnD9xoVv^$&XPBWbA<<< zt@3kWI7jU2znWX=pAo?8Y(#8!mAlpstYXWRSpXIn-un>IJ{!J}h3ru#!fnRk88~$J zxjF&vy5JzOD3P#tBZ02<6DlOE_v6XG;E&(9xbz8F-{?sd+B%=Ve(zwTLhLcy)mveo z1;?(ys|APxZG?IS(I6kdUd2HZ_Y>UqPobP$~Y6t9Ee%{ipqmwX3+in&PUarX1X1WTa)*IEszVpL_+%P|Kvun`0|~q|-HL=(Ch+2Dc>#x` zcm5f%Uk7{DJH65c)vreKAJg9h@X5)#EzjCx7?Or-EZF_f> zgAHn~*fsnI?U%cD9BkTfu^@1Lm9=emSXV20awVu=6a!P z0bz}RQONy7xg1kCW^DGEKwcyS#p44gN9>z1yZ_QbNR3=o49%5`6;Q)<#rKGfQh-KU zQpWChqqeC0L>H`6WTLB>(Ai8dtv|hMzWrf4G4mT?_xu3?=-_ZNe&eyD%P+r(c4#5P zCLLQCfJ?5J9>5RnPGq?q)G!M({rSzyYbabs08qI2lE z2K0L4MDD>o2Tm$nP^De+^^`Qua_Y6eT&8w5VJ1aW*Z#Tc;9?Yonw5hYWaxG^xLgU& zRKncZ)G%`Hd4QJNX+(pdVutJ=!&dc-LqCP6_-rs!3D9N(-wQ#XWL(2eFp~@k``clc zAnVBfT)PW^KkY)QmB<)7|AMwd!Y-s=zZDCJmFbSXAy->*(5vc+7$b)g#!z^vZw5SE2Zo=bl%c=OeL$k8~}> z1PHZ_EyW{QP>bcE4$;^WX-vBwU#lN{JRQIAZSy+;;uagQIo#=!@_v_Sv(t6tYZ0bP ziTotj^x}KRI*_wUC1a@f8cgSxO$Nc zZ^+%Y<71~P`0OHkTblr5*Nk7?AOmdY39NZy=X~vjzX>7p8Yu9UJ>q&*6cGU?b~T&q z786Wq(umMH;JTrYW<-B}__sHDU1FY9Ga;?TvbX^MQEBpl>^Xi7laPhyE^k)f#IQWZ z+|S@!Um9oV^?dEcyiq3O>YX{dNxLH-ToNEaZkRhauX?mMJ-z9uE~sI9Evp-@e!iai z)RiO)dv0L0w_#9TKLw=P;VOzmT14r(R_$Y81(z+ywgL2_U_X4rjC=3=ufM1Y%%woWxvjAlLE%+m4c9h0Fgeh(?kF zz7@zL8#K7WfOnC}(4B9dgjEaXepwcz$-G9Ok6YP@B#TD}8A)_i>TzwIW>@(3RLH;Y zV$;=45B`O%X*ieNFB6wVr_{)FHFkOV46hAtnXGdIfa2S{G6iOAdjl_&Y3V@`?H|X! zc%ss^9_eR1W!^;;>TAr82m$+vb6jt{{Klw;<|Xev3CQx@1*5)Oh`8N*eHiKf@4^!` zTA)U2`RgE62IZSbwHMwG1YvWB-pPhauLS+RBv7;(u+q1G;Ggpp=c%4XQ4DWsQj{4LRDcJb8<0BkE8(;L&^0NioY&E@vtw6}n( z7nNr(eXDU0S4Y8yuT9~jb8*7Dbem>8m*$w+HDvLb^cgU*=PSwD2~(yp6D>AGmy6(l3y5(x&1hI{j{#1SbrNO|3tQte={cCPw{ih z-tWw0JXrhjqs_n;B$4JU1UM%{cO+&IK)ta*ENz?85Yx1f{#8Z-sh9`hudT>%D zJwc~*Jsmw|7hPdT865J))ArZA4D3m|f6A6u{e#y+V`43~Rjw*_c_QE@zD*&aVMQ1d zn}~ivcV<#^B}ublqX!Coxs@pW$oUknfm1g4c$QDg*YO7DK>4nniA64&qUmqUvYt1A zd0SFnLfp=e7m&Tm@e~r19@_HtEs?^d?U*8SzYw9-!SMWzGDfBu_0WJLddH6+ixSk2 zS`Uot;_4K7$I{A$lgytAJ%S6rv%<%VmefNh85Jv=AMaRR87eheUX^OvrR<>xF1ULv zjPpt@xVf-2S~uQxz4YrSb(≶53M9d0?V6NEl}nF@TR)7gj=Hra4`XG?$a83ogUb> zHclh06aieM^_5}2@~dsysN=MA2Di&hnt0Pi8~Qa%p<~^x**#_5=R&@e(i9=nzBh8# zV#;W3A({R+YFG0QbX98(`nPk0p)iZs5GumruiP-C-8osVrN{D|(@PL7XmFmncX;8}#&!r0`A$|BeFC(c;L zjF(j`!7xh4>v6d8G6c7cF$E3|T5Ht9?O0JnV0~%8!o9xQCPjo`fa6Gr#~(6&Tujv| zJ#3P5U{usrmgBS7z4remW5WZgAOybv$LNssKSGAEEVW`NI_9ZjD1ICJvsB^xUGpGgg)!;HSTsLe@G{c->g9^9e$V*$o#WC zd~9>XX9)goM&qP+(N5%^MplX13kUW+Ic_>5%F+3SxO9eU~bg9LkBHE!~u`%npW z=82UVwZb5rSMaw5b-Vz2Yiyi3%F?Zrtgn`ejIayRlT2Zk0cFQ%FJ-B8osgi7Dl5fJ zhlEy)SHa_DV6Nlz=#RCG=FId&&t)9UwZ#!}U>I)p2fW0(SLZ|u3rtcmitIt3O9YT1 zTv12S<_9{JX>(L!2YrP8KBSDZM8Td@3}JlZOMK@z{+K62G|#qn*N%8hp1iQGl~G~` zp_JuIK%Gsv66=`_PUz132>q|}_Tp|Nwx)i_VnqqH>`BAy>KMAzqbRaVq9As`PZD3bULUb*Z!jtMD(4OU9BQ7 zy5&tobb$Y!y53Pcw~XUE2e^CArFhv2;u+E7VKrA3MYeADuN1s1$G6rOS*=(g9C#8! zeZh#mA|CDkej@sG>#U_)vu(&ne?3ihL;NU*yae4MZE#@2$^JxLxb@)_!#2l?U;b#zG3N^GE%Am-K zr576KSQmF}Vqs!KxbtfIxEu4K>PULpnFhZym#h0n&&mt$#fTM=Yqo^ZOfUM+7XHDq zL88lReXF0B%OdkY1gzA{dbclf=N)Rz^@mBT zC5)ufQ`$eHGgb_)G(k8 zw)@E+Y=|7c zZTu_Zw>QyuE)I3O_?*NPK}}b$UT}K%E*>qJlzs(O-Te>ZycFzDtL;41PxcH3Q;ZfeO4LH^7me5* z)HxRUK%#7&%(MaQ3E+Q3Nh#D)1HDawSzsnUF$gExfvv8 z4&PV3|5S2nd1UG#AdxL3wY~|?t=h6p26A5oGDYAQqhM1};T63IYn(nyu!Y44O8|x> ze+Dxd`mB&GQ?p>Rh;T#^IMi9s8Fye)U(w-ku(N1O5X->f>HfOeP-Y+46aZpQ2BR}1 zZ$}-?>BjEygezjP4elq?C&?y>Q4gzFfec*dqHgMlPSLufW0vt+Yh}fy|)^7_{%!YF$)St-c z_0`NJwe|r9hVE4HSPuW=Lgo)0^Pm!})tI?!)Q$o*+mwo_m^63%f_@^)qDrtc>g;7X z`Y{<@xT4_-vmGXVR$67yUEKh8iNm&G-z_oO)O+i*sp=y|(Bpt(Y?`8T>y zFR+_}wCa7vw1XfR2?-*gCa$gIDKzaX%F~6^B{xn)7P%{nw&EbdX1%=d z^sH|1_C!6CB7I+-1KuMn_fWp$8DEay>ZR7*h*UQ{bCmRlZ>2tg zT*)%t9539i6Fd=Gss%I2mO&zMLS%f@D)Y?dDoBJJ9I1{lmF^n@!YnsC3PYm#t6B_U z&}1T1D2Gu=o3T}}d>M?2giR{3^g(R85VWD-T!Es$E7>GaR*(;-M3m{e1@oV|_J3rw zdl*K{&UF8>U5umt0|;5sf{lKs)IFN;u@6v(w`)NwZYa1`=KG_p%}u}>*`7J_;wLqC%v zHw;i6tB5@StvIzuoD6nI4vUj(pcSya+zsn{hpG?sLMMA5yfymuXZf-)ON^!U(JeJu ze{Dk5q_lgLlGGxDMH?dq9?(HYB#4g^X*A zW6N)mJ#3W5qVd%8=P^9;e>O+2Ku~A+$T_)ROfu?*z<25N&sSwX*DPVKu#we=0{Ug> z>wwceLev!*tU-C`CpW~IeAt0EILWZE8ZX|jhWN5mT>-j|RT|90E`I{2W_uG#!R}kZhn|EQ6g9dsN?g^z zpF+Wx@9X&}U4lZOe?%S}n{Qgz`zcUn^*d9CJ(>K3;&eAO;uYa1lc8WlGZdI+{NJqO zgev$cMZD4u)g8cLJHf`?ke*N@**a7si^UH)=}HHxSdQ1lmk#H&lN}Z&cI#T&pbS1fiC$Nu^99HX<;53$$UY%vMTU7RIJ9sU z^Nol8lZXD@|Ki5dS09d~z@YY+1{t3ld-_ z`7+R(LztSXN(r=9U>VDX6?-OB<6&`%Bl)zN4m>P0TXkvBy7%I4ptr;D8Uy4lU=_mt z+PepHf$@rpxS&SXFp{4|Az%OP&@kg)2e|?@M|I1NeTic9Fi~$|sJlYcvKy+Q)^8^1 z>Q&*fegW#b4Bg^+OhG{XC_&3flBzlIrsfnf8|Oy`8+4?E=8zmHg=VX%hH7yo7^FP{ zcKFN;7V-VRLi}sMb&g>cBVdcZhGo|#9xkmC`T!nPeTmfBT<_*CoXk-hoxChVYTXnaXTY-+@H2iv$#Irv z{EUnOpdK?&Xjj2Vxy7W!@&eB8;-tfdvn_kZ^K+Y0An~q|RPO{p_0#8s5U8vNx&`4A z3-F0j#tUaQacD<^5!FXe+nmhybTagb0zl>Ot1v_MFQ@(15*+T#_%;LToc}s?or&n5 zwpD~6=HzwX4jp~UM9r=GeLIATL|nnWL-?i#+)zO1PYEuoVlM!wT?=$@9@z3D%|=9Y zW1I<#f92E+HpPJgjk>Swl)iNxF9$OhuMmAhz`W3~jeC}c{)?|elp!WOu-u$+btHH~Du#NnHnUb{IzsV!M$(;<%bvF!y{5bp&r}!mO@pz@s@t*+4?AZeFm4f4n!@EUL zHFMxm)RM2-Qn$enTs5#u1d+*@ zyg~-scT~gm4xEQyUFhbI}TC|D!FGEN%b?ZQLXleL61Ay6mPqXY4?d%7#T0CrWTuah(lFdn3F(jHP}=Ra;=2Cw1)E3K;Jf|zd7`PF7&0= z__0uk-}Nz6l~EKM;^GVOn>6qiY6?XVPcqbB0hT!$`U#8A#e+Qstb+LP==k)Zk1hv? zpq>f{x)|CrZ+O4BF-CR3PxIH>qIl`1UZ}-`n$*JVZX@^y$bJ=M7Y7W=Y4TOE^g6(Q zACU6o;J)x~OYq->4p#GK!yN*!&7{jxyDQIGH?HWX&GX=g8IYeXHwUaf3}F0ao~7b# z=S8p~MBTDn@}>KrUxas|{;N>Qvw|$*zcK8n9|1-&fPIh7LpwaS1~M=kLJp9O#?j0( z0nm+$%f}CFVE^#IJYhpUI?=9M;QLok=rPIbA$Ni69oh>O# z!&T;1_8|iqxG^sj^FSlBbeceEU>$sMgFEQyty^4yz2H%>0dVr9Pf_A0@A(>)*vU3~ zR^WbTj_~Tc&y^l~1F7xHH`7L?9R!X7kq--+z%U|Nh2v$&%6$YIVN~o z**$ciI%Dyt|}!c2lBk6WP3iO*Qtl0$G^;fq|wLIat%>tk5VbjajlvnE}TeMu)=l2;O0|Dj6 z#cPx?7-z{IW|qPi7MtbW_bo(Zjp_^-Z}~#@k!Qf$FEtikxamPE(5g&EdyaqeKKs)H z-z1s%zUnPGYw%2}aUui}O=@WOg)|&bu5On#Hj1?jYYcpcPwvS4bOCq1v265gMxnK5 zgS0|ySjEZh@igcubR9PIuZSNu2)EODVBnwaMtkSEfj?Z{>8!bd#JF`$4^?`G8IIKQ zU+i8w;61vKhVxdF>}Gnx*+@`kILIS8I}V{b*v^VL#a)NxEz9?}Dm>=;GyXzANbqJBTR^KI}&}s|H4Nx>AD( zJefi*JJ>8|?xJV&dSkbl6}239a62-L^1ZhrLHXtZQU6GXiSL#6gwFybkM#IPgO~Bk zwQ-d0(&U%?LMTtbks6aK_48ktSdW)?{wO8iPCmldmp^sW_tkbtV&x(Yc+i4S%CICHIY)hsa8MuWPa z){zO?tbPb!zUx3uyNb#To6q9YUkQIaky@>=5n(JGwqTNKI^2xP#?4;sRGQJE2)aFiVY-#>ljesljGT z20rjxss|Y|bDXm%%GK|Tmb21{8YqBl2PJnqSC}Vsdk&MF{yGuu2F5%lmw1jgmz0-X zs9o1M=|{y6Gg81h<8rNy2O<4^sXTKMJ=~BQK-s*HCB>HTJRfJyYBd^g4W>#*w-ka- z2;m1vk69O8vwC#BS8H_s$<~@x^w)6HwRIfCJE?l_wOHd6T$yOdWEN1u3b4TL-BoD1 zGJERA^1N-l;c)DlqO|8+HX-)XufD*j{u8%mY{Fc3EtC>r!&AVFgrQX9_qa2U?s#Ws zu}q?*-S%_;KHi=s;YoZlN{% ze)j&HKz#N-rLQm^GyI|O%X?*!kKM_M9BCR~Sndr=rilBdCR zdg_IK;eg;9PR?U~kJ3t?vi3in{0 z@`A}x4l9}rA}B%%X?F82|Er3R2)=OBex-6}46;T0RkW=ag;MNqP}w&3?2^r}II8(k zN`EqO*vdTvkrOj@B{s(<_^CC8DCdUUYa5Y+*qjoIPN2P_2Bi5`k9)NB@}WVz z1VWx-sP7)R?9aSGU3*i;&wVpRq@F3}cIG*J*D7N>{Cb@Y5AweU&Df*@avg{Q)9u($ zJdhp*!=>aB7f7-;IVnJ}B=j~ms8eJ@5^6yc(R?KJM1@!48Q6y?YEY;q0rtR*4)C_L ziAz0rVA2y&F>}rupTtf~7BjRcF<)R_17e**@-R^5BKO*=v7 zuAa*)jp(etC(p+1cmI@v7}%KOapq)YZps^lK~FN>!yjsE9Bd9XbEBXFw>`?({Faox zpM#3OZJhdMa>ni<$9YF7*kqRn>DQEl$oYJ(-YmNeQP+>5(whEsQ*D$5DM9m3&J0@* zjCAdMddV}J%Z0D-&wCV!6erzkDn7Qh?KyEsi@^x1sALzhSRAj6ya4@ANZ(B}Kx{uT z^wdJOZGcnz&S>7J)6o$S(>xL1ofU%nK(N|HiNEX#vp7aR)q%buj+K3n_)aFq*m%aX z|5wf>wpL#9o0lDpd&b4T{(B7E9qbVP9F%&V>}G#O)e%xU^lR03$c0VC_LJMye5!BR zlkGqK{Ii9-vL@urEL=0q{9b&jhrKou9*^`^bD!sA7yaGv31;gzbvl{n4Yr3Nf>u=M zQZ~e5a+mM&y+Qns?@_*gVNu?KS>1B-E249V*G~4J$@&y%gJ58|$3|pe`|Y!qt$K0d zH#XO_)&aLqq&5US@vg5TGsN?W*6Z|f_G|WXDO?*PjK0VD67p|h#?RcKzNTg zWH&T;$QtDAGXB}(ET^JlZiOv4WHZl7{H1{{lPanjRng!9IXLFH5w5~0X3?2c=-k6N zTb)o?{{nfr6nuHy+D+yrxCFAH8+VGp_JaON2KQzSs204%Z3Szw-n>{X(QO7@{9f(1 zvuY~j|0p`osHU>^?UT?$3n7g_n1mWyq!S=GDU^V8Pyqo6AWcL-P(*anK)N&)5ip?{ znga-kScZhEf&wB2L`?t{5q+?XqYjh*oA=9EXPpmONzPt--}~Oz^}C#N^vc^Uus9Wj zd5{KfxZ%k>u=`h|r!-i8VVXZX2b?JdUdt<=T`n6RWy^we(vXUg$pyi<0vXlK(xU9& zUl*-o;WEI5k~!&${l&xBOP1h9!f3S&E^0Um zb$tMZV}Y%9jagSU+Cq$(K;uno zlVPS55>ag!3)bUTo6T{D)>#(WEyx9%MbbT#O+k(-gDO7Hnp5_eW<eiqJ0#SLNb3}e)IK&j6w&R!oK0( z<3+zW290v(MyY&~c0zsYvZ~@XYI3!BY@>KgUoc2??6~cxqORba!FEJGbnIpI>FT*N zK>Eg|Bb%T`q+-{!p#1$+a;1LwSDbsSktMiwt_&ogmlDBOjk5J89`N%*rMI%Wpc=)! z*IjST6Ko8Iou#~9QjL+>x3OYjP8jq9P4A@3Faf=3IUZR-#1!Z8S+jGY)2C7PkUj< z96MOVcVsw?#rEy?vp3}$8omUiYol);f*#*(5m(pZ2r%<6VC$57nxBDM65u^?nUYZB zd9J$CF3*4gy!aGr>dG?90(N`>o4A($HjTB|tVZQpP0Lhp(hL^P25o1^GNz%i@0k)& zgJ&iSpQ8QrHVb8gUev2VFPE8+D&bgdV}^&uI#aeqaE-3jFlQyjn#8LxjuVh{-{vB1 z<@~zBS}Wxg8|C~U1z3@Z%?Wa#`xQr|;sI9CasERNQ^kWTU$3J4U8fLF^4(TDU#?x0 zl2sUtQ$9hu>87V}!sdJuSJi7+-ixNh%c%cLg;P!7cL{?!cuw-J~zAY2w7ymr_;b^1ydQRKHY)3_HsGoM)Qpe9? z^>I`qN$I4o0y!MCafjdA2Xr5Gl;y?Uo(D-s9a{C?|_^IYkT@+TG|gZC?}Xct#% zhYr7M_AtPozy}Id;0r?6FN^H56{7`E`L!(lbb(ht4+OKmiXWDD3opGdm_@X*eQ4(H z8Mpb|4*qP38cc79(23ieGf`<6D=)pDYLgRfJ{A1x|n z$cG}uS;OpLr1Ej3SqhEw^{Z+0tb7Pt)dQp8N<&e>}8M4gAA<9$5Mty3rc0!r#Xv;E36KqZCnp{H3?tlj7)UfJ%1uh zYGM=fy?DwM=7tEy=m6VvELaY~vgJ`{4;^^-YtI_#cC!)ibAy5mZ@9vjxAVQX`oe0B z>ptTKdr{dnR{A*WO4Z>Gre18NZ`55AAC}>d9;0Ez37iQmTs}WG7C#15$M28$cvB-r zp-@{TIT=`re4a_-c_8~)Y1!Z8}k=v&AKa{*Sf#2gtvup|p8HIaFk(zjA$T{EC zAB7h-+dVcpL(qCk5a+I~tcWO%rCm+t)9Ex$k4-qmCj7*flI%m0JvqrLO3oyueMk<~ zMv?JlJJpWdEl~9m>niR#)Pf81d+Hg#t~bfapp4HM)%ASx^T7D}f;FkbHPJ^F>ma8d zcYSAlUR|NfX1#18eq_01$w6Y{4t7d7!T{|q~v8FkrDA~2j~ip^evKi zD+E3xpS{q{PScnbmb2Vz3*ZI&HPbg;; zdVjxyolU;8Itsi;u)+GcpXadbFEaIZdJXC-?T?hd`U%hr*M8ciW5@?^YhFp}Cb zf~uiL*b(b#xr5d!nLi3GMq>-H=gEBe6R~y~y`YpjrBWM3%ewna-ovDfE15LT;fqBX zG{vK_BxbRatBT~f94ZBKb&q8cDV)toj&-k7iuBa|GoawgkHX`O!lAyuC^!XR6o-T% zC7W=_qiZC#5>-Xv8o0O`qtH*Xj~q1Xet(lBW~bOFKC*A?V=xt00{qlA9PTU+dKb=7M@Ot~!d>ooVvOAA7xGTim$w+39)#$b$Vr#*`?ID^1Z&c6wUJp}P@e zP}l%m!QvX{Ag&;AQ{e}^J0Q0(G`S#PROyG8R!C)@&j?Gd#6_yTvWc<1oy&TXHi=z? znezRIik>)J0`=?8(lz0H#v!eXLFYf>M*IsZQ0A8*NpH_Z4KYgGvv%A>EIvPDZX5SD zQB7d07eG@CsQcV+uJ`T!o|94*d@NxPuE%F{Vi%(RdC@83(WSLH){5`Ttoh1WE)eu$ z*{-a2wzxJwE!5%%1ov01bXv-KqK(FJ*|W9?oEWh6i)q5RYe3oj8u{A3|Ln7H=YK-5 z*GeAW8LQR~Hqza(LG8$4rkPV)0OM)E`RwV-doNFvvbA-q4$b$PU|1`lK5Npwsn|>C zrv|dzm4$#+kf8DCxa58Ce`1rcT z-ZO8l8z!iEP33GtKsdtVjg`e+>YKRRnLog`3*R$VoI^{K7aZP|X?Laz!Ma~$dtMAB zs|2!OkJGhFdd%_nFx?;CPdW`ItDE6hWEnmR6<^1a>diylbX^=1RogqLMPL{E!jqA! zjW+Xo7nVaUBLD2>Srnu$b%(SipLKrHv)C1OdMmdxwDW0e zNRMj{P|f7R=DeazKG^34cc6jGPF&C5Yy=I+Dz?)g0)3&<14J>leARZ=B)aDXi?kuN-??-B8`(F)6gy?5T&*tmj;v9moNQ-SEOZ z4ypke75Li^h6eaY!y~%qCy>#}x+jqFKLhM-$IDMgUkQzFdFg0$erKXSG*P$3W_+=( zrQGRl&{Jo}m2W+|M&J8TZBS$1_-&pLM2PCxwcyp;uV;)ly~O=B8Z{^m>LagWaQLH~GY z-C8MYaSw%m1zA6}KMXSe_w3pO%hmzNKJ)9)>!?%@$Q13T_1<9>v)>@ne~Qn4xJ)p# zytXM+{a7jkG89+ocpK(!tzXb-+8H+s`>eWgq{Vdv=ZTNfE)5>QX*G$o&6;bK)fOab zF~f6^ZLX?q4ojzAj%_9jc>;_Gb#Cr>m^+b+1Ez~KSAmO=5jj&km>_Up<+Fy)N(E1I zrn2u2yl&TF{C6@rz()bf3u&{m##yTQgtH*AZ#u%w)j<1lg*S&PyGZ^Psv#Q!?FO!* zsXhQC)1_!OS8~)hFv?^upf^=4kKvLOd8t51WrhMQj2!GY4XA0KV;k^e)%@yUSGIP# zEl2|I2)^KE|1%R;hCtIdoL$h#;NdrThupqZm0iM6rD0+tFo0A?94q>$`iHL3R{sH$Z@KzkiosXGxNZ5dC?nKMnfX@;Q2rfinA#Qb(gMybIqEl-S)8k6@z~CBcst>kTa{Xh(Ce14 z!87edsanL+C0S75o3s#UHU5ENK(|5{IliW#rzMwp7v)14n#@nSn zQu9QVG5@0Zf{Y`J^2)e=e0A=Hqu+ePercNg$<2(Y*K_kOInzey?pXiz1M~W+tyI^p zeTwD>$iD@S3sFNm3f4~Z#ozdGI?cUG)~oBFwAeA!PVUsVWZTm48N-v?u@~b1`wD{p zy$AMdecWugCE{71rfZXRu!U=v+CLp}=f?ZSHiPuk{>J`K=NcH}<`!JCo7oJ%Q{-u# z*Hsd=Y6#~~ds=4!K+%iMy?vtva{K#K4kkD2ToY-TIq_71sfJ#4NOAoC8I>wgmUJ4C zv7o6Bs+(DBYSF`R_+5r#M$QAx!~GKQgKup~bbxA8D+^M0QPFCgsaf4iMGwvO(DpMU zd@EVk#2(k(k8OpV?p*8pGEm@{EG7FtOkHf|a@ocUVr6%(K%2ysmt{C98;F@&zv|cq z8ze=G9kSZeoC1spQ1zE_Aj0M*jr8eN!JFj4ZicdJm_+kX(O5xsHZNsAfEL!vV%fFU zt3{~P9GL!6u(V@zsKQ7?onbOH!c?e9DAO;N*-Nb#%1t)`DDu@po66YYn*0|1Vm`}W z$yDQ7lH{PljoIaL*gzo1BJm@40`>1RCM$Jdr?DA!l(Xl_)3j;pz zD)NS}8tiYXz?}rT!7cNg2sKFAAZr|c6n)~*s?q?XnpUQh$Mg@bS@|$6w=n|`B zY-z+m99+sJ3xaOZr*|{N(F*nAlBHh;_WB6ll_Egft|HG9IZbkDuJ=_&|MD4fW}ka zi_6r?&X^{knM%PV3N%|5{w5w)eUd8F-c2rmZ(_>#&;3tFhIVl8#%ksDRvPTc87+?E zXyg#>g64$QZ*Q1oxz#nRWc_6My+{qk0lB>iel99hZI7`%d2+hizME{at3GGU zYJIstP$=W&MCymJXdDGJM(9VQV#COWe6B4?fUadbhmRA)H4a5=!iKIxB}?j*OTDk) zpQP$BPZ+~A)R4;R%}3?YI7+9Y&>;NP444vBMmpad( z+O<-MG*cWsYJ!J@)(~#Wh6I?1W zflZvIT5U{QC9vcD1Bo+0^SeSz!jqgLed`43S`N@DgJPa66&lVgAsAC~XN4}0m(nf@ zvvUqeXD;Enth7GiuNim=zeV;??rb!ET4>!s6`4<2i&z$`lz%o^R=L2OaUhPCV4i!y zj6snMD$AlOUvvoN-Zz?U0Ij1oP8GD8O|x*VQRYz_xhbuGw9c7k2xEMNhneRQX%zEC zff4NyJoo?eJ_n*6<+U}vc?K{oV(m$$Ai`g59Rr}#D$!{%@C09YGN6@q1dcXggtmHHz<^ZqCfQ-k-5hc?#-quF zEr*nMv6&%YfM^c&RhbMNCb*l+b}pkDrjbpI*)asR-W+*6W88WCbrPWv6;7F?v0X^= zFm)l=i;A&gn&*c)KuOpY=D zwYc-;(M@A6<5bsfmSbAOi3Y0E=$Ol>JU(reI4=K<(c&0RbzGvH-b6bF1WM+Cq&~Kp zEIcIxo^+P{eY?YS`jQG@>s&I$s+D4e3^MDcxTKzi_MxYz%M`?YLzjc%m^G&bhM0&m%S`5=56r;?K4WmD>|tBo8)!CCE906D zLuy%X%mfwH=(LV%?lNRpS*@MXf85E?sPHQUawUMwm}Nq<4@3Tbg=!`1+W^#!$=dwj zpFP(!UkXs&0>~<}1d1tKWEv=N4g7?^GK4sXx47-!3NSwC>DT9JrutEj{+$=FGpvdf zVqpzHm+Uc5s1SnG7~@d?GvQDORQ)=kbGC%ADvLUb7{SWxhy{eSal#zL5c$p-MnTxH z?0M~wa2Cl=-a!F~AjrFr-=P(r3e;=hVlrX+t>-iujYEx-s0MuCTx)^JG|*yNHyR@} z7f}rF>c|F>-V6(~10YI{S;z@zZ6+K>*;)`z^Wd0M8ppu~Y9}3)ciOn3C)=Hh5x3ULKVZ9N zpD{+v2r$`-d#8YN?iOY>P)$;C&S_L4ms)sE(>bgqFKx_N>r4{MGKV=84hABswFQtw|c(c?fCZd(}Hd^JxB~uzDo2;6< z0jSM>mFI66MtSg|<5h1WqBbE0oq#~*l9`KSbOsr{`}`&eV76Zk7F>03vP9-uZ4Uxm zad@r!KM>oU{P+Gx+F@jk&F2jXSE>cq#uvzNE3*Dhrp^W#wlj*#(1rz(5k#hW4*4>t z)U4+MdPDGk-IojofCg5;$!Z~cgL#p+t>{^4w1+aBEieT8n4T@#btyg)QWQ5HGN}MW zM=6YXjG@$%xQajng1l(&I7-xSPn*Clj$$0t$A0vi)Rj5#$4NAH8VHEaSDYt}p|=xz zbWb?;^vf?WO~w~ewnovol?9TNeQc*cRE6pwGc*{pB|l`5(QFJDNPH=<9tNC+P=155 z+&2LjKdhaKymKs-$~;CO$bSpRP_F>-VQkM_DpG~&lz=DB@rY&bUABe#?fWAK_sEu_ zf|!nXh4@X%RoTlAegDZg@qgn*hqjkCO-!4M?bedI9fFT6+KZCCh+fa4W}nrCJzE?w zeyicK9p(~c_?o0%QBU!of_l`AIZu0B*La+aQ#x8(u+9Z;6VGq)yE!sKhu_yrQHDKZlCuWU-|uRtZl63*}P8Mz`fCrG8Ii2QV{5ORWNjLiMt%YZpvBDH==Sk*Dn|=wMtYS7`NHN`$IrbxH4OGH_D~&wh!nL^|M&NA z23;y9hg6J)88+Z!p3{#!r`zorqYfO8NxmQ*KI=>x(^(WKRhVhl0X0zNO94Hhq%qVu zQ>s!(h4PssD-F9D-}y~K{#3Uolm*AIUB=&9p_%q7r?0895xCh>aZN zQ?O7^y)uf^0Q0w*cLq2KdRqVm$zuoD*4gqW?PeTSg_hH41Ps+;8hHG1wXO;tzX5a# zN`G_Un|mwuZA|*R1KD0wqiado*A8vEX0Jx}QT*%;9c_e(3Khe)g{!zi+nuYgw`UJr z91FZ6AN0R;L%WK<+#RtNlhM*U7`t%hwohP2Yc3|Js;b|92@#hJYc@@zZa(|xXKLk zB1-!?idST-$E1)l)atduy8adjs2Z}aZG468e8~d>Qd$qE0C5TI^+BO!tkAT3>fBxB zV}PQxD%P7e_qZjpe%cjum;j}I6z2BExJ?u3eGlvVSJlm>bH2vGIinGEptfW=w!WV9L3nw$9@C{T~ zzR+a_U_l&Xykk0UGEMUZicA$|M-d{n=#fr%Z33;VOkPh)AkA8!Ao>)~v~dl}I8sOu zp~|sVfCC%*9Hs#>ci$A^8Q{@-cO!+7)>5dGz}yNF#- zfq&OqF>sN#D5o%(^d#iUO1pofOS%fe$>UaK{p}*i;1#{=J%*0u5Erl3p^BA(q=oik z{qO}5DP@HpZGT?#<>1jbOGj42tJ?fi7o;=O?f3<-Xr|Pbikg_)ZAk-qDk(f!D zNLe&Lf4+S!HSqU!-D-P2ZL%t? z?)1j15qr1e>a&n4E&9iWy6?ay!fn&~%++aXeR8v^vaX5MPt(VlyCUJ0J>s*Xb-j>#x@J|^ z@c!eqDID!d-6U{IQf&hFy}fQ|l&xnS6^Ko^5MhDkfJwW_>lZ^rqLvH4OOYG!&=J*l zeklPdNnt+sTsGiG@AuIfGNh}yZ1CxHK_&uw)Wz&_^ht?g(~yXgbb09Ul_c+>h1=JA z&L~`(YRX%1`C;(K)8#j@LkclT-8oyg&o|}$2)=HRH+;H8FMWZhWK_;T-(r>|qCp!zzZDa@6>vlzBqmv*XPq{>0lR@QBp^>w!a&g--%NM4W0-2*Bbu*$CdFCaDuXH8eKJeR0tUgQnt!vs<8z^ zhjCjCxgDP0Yos@3i-e=(utB*h-VW{B6%;E8_*Lao=cUo>p4KI*r-~MxH3bU$Xz{^M zF&_Gw^W`!zr@oX^#4@k^Y?jQIs6LicZfUJju?UaTuSO}`ZDK4473+G(MNkwEtG+Ky zUU!mv(H1Pw3|Z_vFo`L^S!tJ1z|G+N=}ENuXP@7}K!c0PV+KFO=iRx@2$|OF$|Rut zz+IuAgw$f6msyfUw?uX2dLjKXOG89!(VvHvth2~!r+W?IgR$YRTN5q%ZDR&*TL=I8 z)ThZ@9y8d_6ycY!I$o{u=r7zt_q(b#X|LkrAYI-Oy>GR+5g${3)~<~20vyi}CyrWa z7kT;GX~ZL2jETHLS4?wN(%g(bL^twDB@n{WOh{CQ6?^0j1zuRQPbIa;8r{Jmo#eK7 zYiWe11)o{bJHM7@7vVW2zJ%JA%-V2*xv|@f;=p{Y?|H1BczX>sx>J6fab*T)Om7I8 zX2E-s3oJeUQC;|^lz%O`9G9n4dMJMq#^dW}MSTWs&>q9D4PQ+Ro33AB3QslV@}%i5 z4I}v!ogVt*k0g3yM4C{uW@p?qOS|b97YOT2_RRlQU9z9MhU(ZEx9ATpeYz-wV`mB- z-?G#O9#CKd+=)*_fd9OZ3fuc--sQ1&X>_lXH)%_|>dSS|>23VLulRyvsllaDd_zQ^ z%zA%=tbAa@%=5vsIl}|<#X6wJ=^gw6d}u;j*oVb;mAOsS_OOQ-2k znW0uq@0ij3UJl-hu1fkQ%)b{@9)>lPJFl|AALuOz5oqfkL~`Jg%|c58H6$~SqQ9$v zqieZ%F6UP=5noqqfVN8;C zIO5nc3jC5yp-H87S(11SIvrGNuQCq`RvAMUkQA^A%u&E`Z`S@fDWQMhqF{Y=NB7>cpl6Is7J7#48Y15p~0EVRGL zh_(u|C9ak$+kQGX`^KwJ)o%t5&4mRX7e-$pGUA}j(~=U1s33TwQ0H;Ygwf&$P)Jnb z@gD=S6A)aETuQuYh2gU{i$;z(cs`AslUHu84gYl2GFNBMJ!oN&=@+3@6 zM&A&8Su|ns7>gPn@Z9W}p0pM-OONt_h~b?c2kE-W-vwtWh&$7s_`Ex6bM@BncZ_K3 zIHp=IXAKqkL=hkNx%g1G7rb9=;YJwyekdUzhbL0N)xInCrU}(@eoznuzz=*rfV{uV z%SFTm920k3${|erF@-JKgAoNSgj01i{?15jKTWP&)2)J>w`W|q9&Hi_zlkigPd#3gw6Ug(k=B@qm>P79 zytWWCZl3=6JeBc9>k{KUMoyso^(#QH5nR|Xv!NbTOkW#M^ z%da&n#TOo1>}`*u%%(CDlQsd1O)UK%fW;v%^bLX0-VL?4N-K%pI(NnGJBjH8y4vq_ z?fi#$xJ27R{?y)9I|F9=kiX7ez{#0%ILpW}jtNUuKKbp5>qg?8okX1&8Zl7-OJ>^} z({-H4aQJnYjAD8>c27tP)QhL-ONQNGp2?qbZo&?PNHE)6#~;GAA|tJRWCVeX@Rf8C z1>XJG!39b50jz=a%Nu=R2kleAU~Dts5zwIj!Q*sWZoebm;M|rF2B|W4Ksiok9PYL5A0BBJOJ=`g!{60niXlJq8<10YF6n zLcT9Fg9lr|o?yI#rvu;#fV-}A4QHn8mur6^K}N-eL*r`35-6xGtdjv9@zk#rz&gnx zBkEd2fv!lQe3J_y+UY~cs(yk%&Wu(i=4&fY56DwH*%nyIg!BS}LeiwH?Z9S$wtj!$ zD;g3Ma@n{|?~@p=?s<8f2d2@5NJW$a?@IEUKuv&Nh5#80L0Mt-hItjq1>hKo6sCd= zSC?pYW8oX4T2hAA3KuNq=@c^mYvw@`02L$584mdrsasDxt=mhE`hn4C61*J3B1CjT zUKiZsk<*UJv5f-ng|`N4Yx)*JABLE7+W_yA2?laXhy{RkfAPY@HG^HUL3=R@c5<`k+bvi#CZ%KLk%HPo4H;#Qh-0?vkA)@1MnJT$7Sh#XipN4;ON6Qi;BV zwztm{YFykcRM>MZiT-GAMpDM4dZwA(-URh`rgt_c^Ee-DHoVfe*K|_H?e3Ex({q@= zCr^F*bYevlX6K!q4LB9Hnf~dmTK?3jFo2%18LZwDqB*Z7qU#JuTkyR}J#Z%c023M` zAYZA12GOCBSXdkY9xO1}l;qqX566)aapZyK8@Ebs*q-r-X?}b-Mh($_!)k$y5K9c+ z0ucI@Nb6qRWs1Qf^YCK;nxlypJ+`h*FiPaG+SiKm4&RW3#e8*l!N#Ie~6JG}0 zXtG$Qc^zXWq?76V6T7!tD=7H_w3qkDS5enVqPyV{5WqD3NkficjORD?HssW--UPPt zRQZ^R{4dZLKu|kYS6u+74Soe<^@8aCoooY3qa@J6?y6)7%y>qROLk}$=w+pVxflHL zioaL|mNx}DCnvr&NOZ*#U8#aBGF;C~08aEo?#AlHF}2j0U}>zD62QQZDWe1G8p057 z6G}z^x}^1>HwxMrg!cwC-fG!VJ@ZV*AaB8U_qrl?l*Wtuc_iGFT#6?!Ux{90N=&@ zug<|I7th%rX;E_W>>=Quj51LcPJiAV3*X+K-o8S>UwhmB!sHHpLgo}F=$~6`G%RJDH)P1K>6Zf zy?{=s5JjTHg)O<4-rFbB5!nFP8ri^Oac>UugNidFo43hT9UiIpAT~J$FZ{qss9*7n zS&4H{_JaRgW(|6dQkw(L8lvU?i9T5|DCbPeQ$WAMZ$7+1>m{3R1Mn<7URd#AV9`56 zqMycu-`#>wYeN?u2qHl4IipV=-{93x0*;g6+3|1^K(C*?cSsOG=AAkwfFbKAR_KWx1-nOrm*nE{YdLsbs1W7s66Y0wgGnvo_W3{&T7|GKi zobA$zSm=f{K?vgMbTWgbt>6R+v|1p&570h$B)D2ol_AlMUC`CU8VoS;eG|Ho>>%;< zB@c2yqO5KOAQ&d$ngDGv4?6iwT1khdOIU3JWCIVpCVh;=;0zRDo^qt zTQNGc!X(EWWyM1@ACGF06&6^8sb8Q_tUI3mA>7$pzlYZU>6T_8&4U~xej{t48>|To z-ZBUK&13Rf;cH}lcLB@+pjY(SI-NRXWp7QiwqC*>#+SpabW#`Supu(qPiKBXfO6MUhhyZo!_a{cC!p8#~M|58Yh(~+JP3l$G3-Y6Ye z;n8_)*i_y|y+u$il%dM`GZG-l%qX zE|V@M$F%!+>rPt9wPzvP=T_c$k1XF317n755x(i~lO^y%JlsP9|74|ix6|ejHn)L& zH-V}@mVN5X!JJ4V1cgq>5P%O!@`+^V&Nnzk0QbP={?cy$UjV$H-PI3(4*dgf2X)5M z;hnPXGE;vC@Y^m=G4;CJkR2w*l|D4xX|JNFZ;@ zshf{!b|%=2Fj&U)eLIG};76!oEbLsJj^>F|dN{-ZvWLmOw3EF4WJ$%k!0)g!d|u%6 zP;?FP_EuHt`n!Ej5Z zr&NNC5&9T>b3yDnZjIXx$9%|7!w$!Je`qg#WpVT)BOY7j31pJdi9ues*X>6=XF#m{-p0Qh!T@C4a>ocY@ilET(zAQ%8| zCzoy}bn;wLKmasA5>z9b&Lgp1V`Olw1eUBg4x+;o_lIx+T6uz=D;J;?LB-rA%vI1m zy8;nj`d%r3vLrAPHo*6dUTwXJ>Wpqbll>Pxm?Qb~ElSr4plt;}xb6?p`u(tl!*EOOLoAH7}^X_-5sF7V*n`pG{PD|WexOIhh_6Yst# z$}4&y+Vv|s9ssDcN9f8yuTWSNPe%%J*)T%%l1~Hdwzvbbo7?LE-784x|4;L* zo!UR1N}p{q>rO;KYy@iHgv`aAbL#7qw@1~GW15(0&Fs@Qi(Q&Z@u~y~bWI5s33wIR z(wsuS62`kV$Ao1Aw(`%6WHDVkyJ)p31eNFg(*bL@;^st@m~(yJC-?mkM~BV*j86RY zzV#E}2@@7X{UG+d^+5B%;HN=pThrmA*Zx_i+x7#Y0e-!_4TId@b!VhOjK0bU5 z8>)lfa7N6@2Kg`8(5JdM@~LFm{>~T9U~7*w(XvCz(`F5j=e~E&$L9F= zd@uPWVAJ|kH0%>L12IuGQWbls`Z70s;qGkhrRqyI#ZkOtH<$eD4=3P78C@4Yt=>rc zXJh%!KdLX?v%hToTrt@beIv3q^P{j}UyOC`oh%8(%=f90UsK%KWzPF5D0~12t*a z7(35y#j>TINV?01x%HU*EfKHS^e`=M3Q?IaDUgLQ#VMP0)fM?7VmFzb^LY(*-WDfC zpNk7kR|GVR4sOUpi-B5JiPg)-nm7&sQv43)`DSaPQ>qDm_HWjs%(}>kM^3X7Kwo@+0(GC$yE%wd&HC1~ZmHQV zF{IRLAlI-Y`FVoTBkLRD;!?Bz$v8E)CPv&NYtam3-|-tW!pCO3)sPaarZ7EUeB=1q zL+c@i@PXyMp}0o_SA!tO+!{D*C1#D;z!Kc4RNy|VYDWJ9dvTe41wmZPX31ObWL~yu zK$`fInbRTV_VWaz8Qe!g3)t{R0|iVjV%VuV=&7FY4ShLntnSphG~`&A8riCGGN#?%iHriBD)%d0AQr>C zaN;lDAiBi5L?f+)&W5kcW{2}Xt6Fq)Lw*{L)K>fUH!F>(;@fH$ZdA8%eP24ho1QEo z4)+3|IKIej_q7(Z$!od3+pvCu|C)Qr?bO)W+o1Zy!pi}!8;#BvT)!_qZz_;tR3v|V z=H0AdshJ{Q|E?+Zjy=1&4=E{FnL7C7_vH(R12uK@D`Hori7~$pe#Hj^(WAtWoCSB%T&gNMhtg~5!y2kVt_FZwKIgX(qNWRb4Kt z`qz$z8GC7DhTLIHZJ!jBSnzh>1C=^>ai;LoB2X`+zreB6q&+g10uOz-)?ouP3(gVx z$^h$WYr!R_iYf46aShIA7*Np#4>f&nPsHWHQ0P^p#wKq zHCN6=_p!>P0hV$zCVV^YqTGZ7M^bjlkD3Uw&x)BY`@RhMrb8l9w2hN{kpCh8xkD3Nqc9nLn(Yg*4usyi2qC>!iS zpAYpUZfBO?U#8k$tD4ICm3<+S&$7Ev|I%y&pcNTK%^sT8cP0j_h2E&|-h{f<95cQCV#UjmW>-nJ zDYyE>vG{y@kATcXy#a=|esr+#(RQluhzvTLL8Bn{?&zO>IGNpAYxa2NhL3$3bWawR zkUJU=udnCfg#&GF^s(z*&5C%6ZQIFJ^ApXv)fdGMZEo{%QK-95t-o+6K|8*F+>Faz zk$^oROcuIZ=xLd(?Gwxt>hR_`mI>qs0i1A1HO|M%l&KL?*{oADJ!wh8s`;d`bXvp; z)&#+gi0n9>7>uH;ueO@sz?xalqVnD!SPg`Nk-C-2*l(;VmE2fsL!pf#(Mq75ylRaA zYpWWq*p^Fgc*G}hdlft?gH^)9%jggI9NXJ#Hlg3*FsqzG6Px*&L_>M4X57td%c`{n zx$)o|48_1!rE;RUzf~CSIgV;~isAzG&s7!L_6jr`IY#>HbxLmerVj3mX2`Vc^!&fk z@-%RBbE<@GX7yI1VsAWx1n9tKzbQ+AdR+UQs|4E3(0a4^&7~=#;?+rQEuTiJoBcC$ zC%=h`6Sz*hFLkmmCjEjCIZ0R&KQ;@`tkAy>%;~T>bklYM07^Te;mwL{i5Z*NXG`;{ZsAYAlumEfogh`5v(-mqe)Wjij47zYj@!kjZJThE%Ojj15o&ejoh)Vz&0hs@Fr`r^{sx-+vfKPOQOusza>uu&Dk= z9f%ioiulnxR)<6s_)fX!Z?=Dy$B?q%TdV_SQRMYNM%;m>=|W3iX1U*9q0UgUqN$5^ zsekLXPAl%9*|050ruv2bDdJda5=#$ofL0S(o@U*_%?GR2boxuYt>!+1e$SG|>;IZl zKC`Q>jvG01eK`ss`yeR&Px4t#PBF_;MI@7?!WPU9^`Jtr1g>t*xF81lkF z<>8`+{b!ByVog3N$SM~!OafB?8ZC05mQmPIiSqMUwR1bE0up`?qw`- z)H$(gFTmz`izTQzD$Mxbl48r&+49WHKYOz#md)Bs_A z?}XZYo5H>w=gXOH&x3J?DY!;M{j0X8LpUHGpjK0c=2e#LJ4iwPZ8^j7X(de4&%GMc zeT0j2t*4n9+)M+}B(ALAwMvo+%_cskqwjqU`P!=~t<;nqKkbsVb}`tGwpuUkwJdLi z>zZNyCA+T}?Ny~{)vv4aZ6f(Wk)F3eP7a4;jJNJN>d0vPwH?)_AkX*hAd+&n(?%KL zGUyR9y0aIgnyu^c{S0W5rrB;WG^*DFR$E7cnn-Gy*%hq^t@?vN*;!^Gh9H=i^=&ce zV6z&R3+m}N%{;%qIYa+8_mroh*3!J0b+&FW(6fnU{$eMT6L&28GHy_AkAOwFnW0uB zuhq{~y9ZYLCi0Y4}HOA6zjcSfg=nnU)I!FsuxxJn;7V4m(+PyET#G$13kZseegVs(>+<{MocVLknR90D+`04K_wzi}0(k)~p-BSk%$7eiFYl%q=Qdy7 zEhS6h;AQUc-fV>?p?1}XX~|dE1zAfb8}_BdqeG~;gi=@+B4?fg<#7r+4yX5 z7DfIpO8KEDbTC_C#0YT#1#6MSo^=Tp|=Hx4a^XoGzD0?cL0*@ECIS$D&{%widI}k8Gwyd<1-iI}=nGX6Qxtb- zh`n5i>~9R@=Q|z;JVK}XHM_++(Pn166*GDoR0@z@U~ zQLOmrK1#HDXj&r$)=RN0-hw(Br2RK+)8vs^*Fbo&Ab(&-OLD==)!lP~2m8KL)|ag? z36Z`C@uRcwYhChn0+>~#Wqp@?^|JJc>lPIDDz_|_Xl0Ig|J7|#A6Jse-GlGtc^&LC z0%#)7>^DsrJKweva)x9UFLo}!YgB1avMAIX$|f1+nP-Mt-V-v@)8I)X-nVCA4Bv7^ z6W9vWdr8*&g+%^?X5%dsyj}u(K_d)Ee6E#fwXqa#cf~br$$vGGUlhuJ$%c0cEjzOn zx(FG?GB+Q5pk0X1sK&wXWT!vGnq2r`)4hdQB`E$}M*g>qv@8MBDPSw1Y&-)@C%`%- z4%$l4ZkYgGJJ@1Zw(_&%=U2EH{qPo{UnL5WD^xU1y*SUHOwB2Ck`Q)stXdhkkI0sb z@_{?qaV0(}&+Iq%E2v@Z1JC083=!Kj1tJsebI^fg$YJaeDJvAknXQ^6vWl3d(6?Oa zy$iV`g#JNM%aK5+7t|+~DJ&X%QN~Y*Ml@rPm$xAB-QM|SICdLsO;OlkSlz*e@XzWV z^M-BXa)P*|0Zfi{0^A)G^9Na{7+d>Ig8WVp9KEMITMB#m-YMf*v?5G1<(z`+H{I5= zMGFM@y?TEc zuHr9=zU2Y^f3`7M>x1nju*@)w>0M0cztk^E)Veq4lgExVbwNY(P03^O4@LLxPz)Te z-ZEnQy!uuMT_cxw#kF=_p3H4Voo{CM$?sLYgVhP;Q+zXiv+%K!_h|#1@7-Va(!k<* z5By2J1v)*WuEc^lCtriL$aK$$uZNwHJEYMMd= z1=dMX9@v8E$3edO9vr?vyDgo@8?u^)<$DEIEidGIWtwzMdfnCFiwyU7D6@Fd&0d*P zYciGpZGo*Q;6xNS$((mQTVWoBsHDkH;GlPxamqv}r%Qn+w7tWHKV_&ZmO}?5x3!7T zXA;G}HP!gEE^CqE{$8aS${w@_CodfF%b}eF0{kvwV2vnccq-H(JQz#n8f<@*{F=}71etX()tYlt3 zMRfThfS-SCHrVSjFE}v8Vl2`WM#C8|vSGX1^p7~?yY1K)LfFz4SFXi}q|fX7Soqa7S|M zW?4*McG-<=`AMkuQEi($vZl2q2Z_!ki&V?L)Ms4+u^UPI2m!vqFr6=nYAlh`klSmz z;>sGwZ_YG5tq&AqMgFvL-z{@_QdEnox?XTMCEYBG*x) zdvUZ*6AXCc8h8)~%g;uO2y7{sGnav8TAntatQ&Z8? zvw=4p6{l$mhrTGB0Vsr@YOhHNy zKiTT^1+#usTXQ>GK})mzFI)U@oadDXu+|c1H<08W&SS3b?j??5#QweTc2sQXhUOg{ z(hIM*!(e(`o_hRTQEhUNH4K}h)o*tx!Y?Sy33TTO3YRV_d=x53QCW>i_1{pM^f2U= zOzQtwsriX+MeDAQqhYrxFuNA{>y+pl82Ka1(tRrV<_@XM=xhy6zOn0LIR&c543r^U}&RJizM7=k;#oB24rh4st8$rELtX5nH}3;FhnE*yME zc&h~cs&`qj)CgXOg8h&5%Kp7giP5dKqfQ0%4tHJDZL8f`Y%Yx648z^$I4~jB9_R z6kp@2^?UvQTT>`3fj^ZvicrviQApy-__I#=sv*x`y}U5H|pJ3lEOHkl$mlmeJ|{+w??l@<2w(HD}WC>%ksrn_g}*vDfiMm z-f-X(MN`IhUu=KaeES61t3%Bp9Fct0_gv9Si&O8Z4nsKsw4sdeYk!qrnLgs2wlVjd zoZh47n(=G*`Pr*IuYJxr7?&D2+zK@fXf%>`i@rBCb*c6Ju?DoiOz}BybUMTLW`)ts zZhO}?wv{|HJe)5#|q)L1E**E#u&fQX_Ms~y}%B@d^ z+zP)NqpR4YHrL@_?cQ;H-7&PKFVA(tcxrFH+?3K3^VbgAvd`vrebV5m(Uj_2G2JO~ zui!BCPWr%d`kAl~A+6y>PtQCzno#KwHThouXb+EQi+SL9N35oMcaFO7^8MP&iI8Z@ zUr235)nqk#(7>_lSJV%Ow@1~w$NikOi0BON-^q9T_+G_|QcHiXI;G-seBzO`_b_kt zcJRx`In}?703bT4yP9j7YBuO&JXIU}l}pt50jZPOK6 zi4XXydwbnFDgi|_DWUo)fi+K*h{$M9fl?1DWK~I7Gl@&|P&-lU{uXtZo@#P6L4f6! zQx)ZpA2wUL3YIpPmOi@5tYtmAnzT1$4pfhxayQXR@Dn9ovuj{I(qQ)ec&mL_zI~^URjfrQ8$W;L)aRX5LYM4 z9wp5}htbi~2U{wcYZJ$7>G(2J_2`5k-h<>*-O>t5!r;2r%jGuX>hL7dx)w7&4pPgU zp75-tFA{TKd|mE_RL6A1od{;WZJSq1XO4m*GS9d1Dx#)C68n?g@<4hKN4nRwLPIA& zsxh;Jrn*t5eBPBOrPh6~Nqg;j;xg}~=$lsbsZxmgdA)v~TADZP?kcbNfr)mw!H~)2 z{MiY!I)>Nr(fV`Kk*3%8&XKeh4Y(hP?`ktA+wy9z_1g}aGzBT&L|Fzyc+rZc_?{(E1F>Yu{RR`-ePe0z~7h96#GTX9Z}3o zh#@uznkaXQjKfp4=);B*-3cp`5+kI4Y_>&2;gI7fTL)= zlImCos)%s@-{<@Oq_&5Z(ep0Fy}|K6MpR-+A)tzwp$22MAm`L36~F0TaOY~+1bDZ< z3)@){{c`b<(`1lh>K>P)GfVX_4#$Ai5}+W37=RvQCdcuY2$fYO_0hBVq`sTFe5XLJpNUaqZZ3cN6?3Li+}C zu!Hh}z#!26qakOKCgAvEaR%^P|Q&mp&lKbeY=7|rU2g>4#s0n#I z5yllb< zVs@_e(DU2WspINf8=%vxLDuKF`?WtE8#^cwbubeQwQGc)j=_~3$xOK37Q@(Pa`e*b zinMWICf_3&m*XBYPFC&B9@}+`gNMa59LmWAxlrPYj-5tme{ssSIMS_frn~cK`H<&s z)w7PMIfKMI0uRfzIAqN0_=x!^q)F^m8zV7Cd*0j2IlMHVPL+#T;*Wl8o#|oiAE6kt z7QF{b6;HJe>wTb(&C-I4GuwytzvCvHjZ*V{lGx~}YOcEPOi(ADXL@~OBiNIra&y3d z?gN+(8p=^+ged=ez(o+*FTyO z_(!^e4X)MFwnhYhTIE%Nz8PI9-=83>QjAHCN2q8-r`4Z1s#$2lNPW=Z=dp+GktmvR z$uMfx9hC;BBUZjEs+|lQVx|i+Vp^^$P_kN46XaoOZ%+Nzl^FbfliRhB3NN;Oq>{7f z9zhV@_M)?o-dTtOBtMeNW2$+{Hi%!ecyLqY?8ATd3N@4b6E;DK4O*D5T;A ziaqXu^-bjO8xd=Soyox#8Ww69H=z2wc^QdR{^|Rq@CghT?^$Q@ZKqDR8wI{0C`E*Z zrXCJ)=i&oq##N3VQ#-N0+}5*O;e>aC_TSK4+Z0;`hBrBAZ?Z0^1f~Rykv_;jw&XC9 zuA;MDamehA;o?ukf+)N~ibfE@>QS7sO4zbrcd1#oaRB(=V}dGY-V3g_Rh>TI#S zTvlgvroqWs8^TdD-OO5GN4;8`#Q$`DgtSH#k@3CNL%Kz9X?-MK^%Wu4wMX=4$)Z^8 z1Aol1U*sHtff8Fw62~P<#o_(&+V}Xmt}lr?vGikVfS%Eh_`3WzvF)lQ55fOe!<)|R z`mCjv2};v>e91l8NOztFvd&W9U-Vp7yY69*TLiH*4A0S;qe85;&a{^mTE*d?E#Aw{ zQA}wyq?Zf7n|;WJ^|{JkG&1(EXt&L;7K-hz??IzY+0V*1>7($0FC^0{ukgsi6D?n+ zE&edV!WX`VWe^`MS*pnuw2a7J961NFBD|Rw9cjMZK`n?ly}gerTwktZ`k9bE$_FHb z;jWu`#2R7MJI@)Q%TPr}Lgco)Qe#wR7Ikc-t=oI2%Q6=C5TSUd~ zRxWl;o9~lGLRsP64U%S#Bmm`+-(~j7G8An2)L%hfS=m~clkR**bJe1>zPNs!ZZYGDLWLsdr76 z94arrBsY2dDt9pV+JvcdO;wr%=+ET&IKhIGc-NP?ojsgC4M5Q{)svRISyFXrd_EDK z?^VUUE$$Csa=o$@vo(=)!qA|E8|-~E0#nof^6pTbvxvbvvRgYM(T;TDJw7=28Dlw3 z)oC3V>>l8CYutR74)n%Z`a!jQC6>$GoR{fZJ_8^MdcfbwlMWrsyH_(H;(hJ{1tuvh z4;@@yEKCbVPF7Vcb`3q@gQ8Hp*A}Ks3Mls=Bz&1WPZjP7dEb(F(^b5ah++n*WRC#v zI;+1IOGO5uc!8}2L6*g0LdofI=b?$5(-h@$8)9ttdp=am(#w<_j+J6jJZtsf?Z7qHIMGnvBk(SYiHstIG zkSGG@+eeNaip$ghx|so0wmD3MayXVt$cBVMRg$*AaT<_6Q2tkz_aJ{T!G$xtFlWfM zf8nR5Kmq*RC(20S0yvXV6fA7qH{u^iZlOEn1~Lb)9v*saU=vt1_FV*`Q!2yEDya;Y z)cZU-Etdvs{$JarMW`wTI`*gJ(0)IMzwh(H_`HY>cOlFqh7V0#5&CG}NJ`aBI0)I7 zd?z5gDvZ$-6UB?!_W$phun;;nnQ9y=%wa-Xl4-o(<3XNI!{?}uXBTo@(DL~tjhvg% zQ{Fs(4PMBGd$0yCyo(cF#SOyb-;!5s;p-?~2G)0Rg2v@;de1mZxD-Nl=eBHs8OVPC zbe;x@ra_)ojWy*Lw!hT5&!8Hj<>`jv(m>8I4L(!_`jC3#Qwhjuz3{pigmWzpu2NX4 zC}$FKV+D^y*N|b$b`)lB{DfVk%ye)zWI)L&qO~YS@F)_h`s}3?zNr7tR^bl<=o_wR zOG8ogQ!ldB=)9kazo7CylNTc3bbDU8+M37E1hon$ER|Y4#tJ}2|V=Ls)8FGQ|F>cHHd$62<8GO3yGGnfKO_V&JL)|e>-hk1j-ag=-k?*swjG<7-a zT2&28)04_=B7g!m_O9Wxd5sJnar4?GIP9EdYzu~`be-#~QSA(Mmv~fH)9!}TYF{h> zk59-)bU~5@-1`&Y=^swy-r~-4$oBf7*GY&QAAlDyHo+)Rzyi<@+TqW*_VGG`j?!)a zP#7qDwA0O*Pb9oy>OI}fPbe+=-vrq6KW@zU!T+rwd_8z1hnOt+=0n4v^#Sz&O5gW6 zUZ#0VRCX@!WKqB@?@L=_F!Ww{l8$KAZaP?h@@3&Uzd^yjV+K?F@5`$GTl<~QURk8_ zlA)JZwrVrp6hEhD7fg}vjmA^0i>H`mzNwu<+ zmZ1IlI*2loLoEl11VEqhHh-qUC!j(k6EwYA{3NoV^Ull~A9Q2OTyp*4qcY%MyY?ec zYu{{O0Nyt{VxH?)pW25t#k6#}b_!csug?b%|q( zv&Di*UUVyO#)tQ$gcsuc6zB;F9_MA5zw4?l-adFE5K~!aw|D$IRsTPD^NCBK0H=b` zt+ySP`aj2cM{j~+DgE10h8K#V92`^#g^1pjkm4I(gcE3hNGKqi7Ux``Zy{_jT4_Ho1ckQQWuuFenWFee(Y%>yEQ9+`6E9FEWGsWsr14^%a%Da@gv#_C zR|ci^fP9lre|-Zq{Q4w3SdBj3JmE&VB`CV8!V9Gk&(>f0UIz;PXpxdcx)cuTi6NY_ zH*HQvWzt$gtI&;2ISN^ze--meX=ZY|AmQrCP{u8I+WXq1nNClhE57;=JuE3@B&;6ob*2>Dr9V-pQ=H5|xyjH35Ecy!pp2Fr5E%p20i2H1ztunrw41!~NCz zg2J@ZjoTSVXC#LxG9hHp+vF{vUgjM@cqLWn)CPmBi*xTE2D>bPQwWeiLaxge)y_-~ z5yw&%mZLEf`%xUa2G4$rL)hZD2tn@DaGx1W+`xUpI4@M-YCOcDLxE!4226Z&+xn{y>&EaPjGR?zsBW%1|>m`ZQRz4^RPbnnB=W_%#0ehoRiGR zvd;q_t^yrc*;6@*0#_YiBv8OkLy!=O;UP?F`huLjll1V5?gHCxaNM>yt|fBkop>|~ z`1UWk3%lI7Dy0NyZlvCmi@OjR1`8tj8MzxxehNA&K>A&qP8f_fK$;w-A?L0Y|8}z} zN|t9{19JX?c(nkI6+Sr(t9mH77rDz_6jntofG)}@T&?an{!a#1z#_**Hxbz`4_nOABmZK zzUE(RKX_*JT)Cxxw_{xVx_Of+^`_o?aTBcEBE9e8*v{9?aL2GaHgAVUYsxKR2h;Lq z*Z(VTrGn44?Nq=k&WCk3dQ}`eo8H+`@0<5G&!SpHU7Q{bX|mA9H!ZKqo86JNds=q$ z*^f>IUCI_*9N_Q|co{!B% z?H_jtP!5Lb%PYluKH#3{79S5;PMCvi8Na_l!!DQa16tbc6Zm|=lPpD_E9gyW14 zkLz<~#1QRQb7wsDGdJUMw1*1gb8cVkii6nR zSh_r`b<)P}dPuM%yt;4l;@!qzxcSYga|#wOpm-(ceOK%YZ0=U<_*%_T*Mh9^nsYGI z+KQbC%Z$tMY53llVFimm>X*rb`jeah*W@d~6Gpw%5Z|bCQ)P!ux~I-u(Jd<}%Z-YA zFgs~}juEVGaXI04rs=f`k`J{Zi3GFaUI>0^&dJ^X*t$c=&Q~wP?|AD*)0Ne%8$)+^ zx))sKcFkJ_JI{?4YW4?Nl`oX3nn#q1U&ac&4JJ&kqkbDRsoC?EowWY*%Ifw!?{fVS zD|5}S6ApvzR~*c4w2H$lUq z^NywKhaCwbOpI_?>s{BV167(ypCXy;4+^lhFNKG;vOO1DJ(b)e8#gp=8YoD!6Tu2z6dth5w2I*{F3x4(>gLUln6~04r_8U&<3-qFgId@b}kKM z;kj;}gd?fXV8ZO3?(8E&-=L&@q~>e3Q0d@7W^E1Q-Bv8^cPAB+(ZhN_W%#aDvU+?C z@P1SxpYHbt`9wL#?dE!M@;8pYVzy8|pi6;XAy#wU=31oyV3Evv^$r5a5#LJq54Uz; zivaoaZ5)!hAdcMkp=9d*R;0#EQ2CqstZ}m(KQa3<4UwL)577v;{+CqsFL3 zt!F{duM~d6ILC-&I1o}(f3Cr4)XW!9InmK%6S)GUn%L$AO5@C(Ror-!AQ<(2Jerd^ zy1Ad6f3JR3HTk56@s0qpj~S=R(EyseaY7q*I#s$D0p`A`nxhHEYS&wfjI>(v!gW;l zb!mW-P3)BNW42WE=6Dupr@$Zy$)5ae#N=&=MqpO`!KiW%>xaGaKeI{b`kDac#Ro|_Fx<`PGc4po?Q@HNI=9< zS$bJ@Ik<&!#5^&YkOqf)4gUG&DME_TkXzwY|cZi?*})#CFW=P$!Xp`lY`4D^$Gsr zqkC^_t`)Zq5BwS=^%{g?Qr00c3^Af_iKFqA35ov(Q*g6w#?{RrJ-0-ti!u%|b(|IU9 z_A_t7$cXg*yXK@M1=9GMa_@fJ=8q%98i)_a5o_LkQLh*7nnb=W0_}`O}_!x&9LTfX} zx~f-G;y+@~g5tiOs;pyFH~#7yEx!4WcG#MiLDho4RTL4kH0zao4fDRR@zx$JsOAmn z)l5<9CAab|T!i}Tg3*JYMa7v@oR2C$fo6Rz`TiRQ=qA*tY5!Y+K+oHC0QJEt?Xph6 zatO9>l;`y@+98T@V(7h~ICHrU9`4to_Ko~#f6huvq7_G>V8H9vKPvPXU_A07W4MUA zMN{jesn?)b8Y48e`)<8_JdJ#*e>>mwUx3m>dhkQi<=_?2E4-JU)rkxD2 zU#s}EvEGo*OS11avhZK$ZI(%}Vf);aJ~a4)iHUDpxEC+4oD}p$W9Xpt36} zpbxk{Nq!-Xh|HOCk0&AiPKbeTvS!k4fMvFawZ!HZBK3kf72+A z5I#3rGH4hFQI~nTX(A)0D3s}`>P<3lWvNT{0DAY?Fez3uRAhwmbfn8^+#xU0ht2S0 zA)RY&A7)7-Ta~aa!sQ$kJubB6oJteVY(Lsy$|m%Yj>mB<$4QowP^&b_1$ake|- z*^F-W_cX6>Q{`4t6r6+j&zfO8AFygyr7L}l3shhNu`@lrTqiMzIasrFZj2A8GLPS9H781u6~UzITwlHpZE z%{8dt{sMNBCf)*@`83EJe;a%?C#- zb(F|rm*n+2q-q_`es4IIF?MHpO32|TVcp%>9GB4 zVvDUjyX-}~ud;@Hnph*?HN>nQM(i3Y+btmKb6JkGVLh`6n?Zm>|MGuLr1DhL2tUw6E$QUkhvP1>nQHf@oIoS(R50@D9UiQy5u>)qoacRn7goIxphj3p! zp<-Wr3N^1Ofc!W1s6rN>g!m&yE=@{_8yRt22yraIc=XA75OQ+`vhIoG%(r6vk{Gp0 za-m1>$$dL~=NtF-C7QarvP(xtHU3dLFymTn_Si^4Leadblgh7xwX$0+J}xMPL?aXN?E#c^1I znA4Ihopj8zxKAI5B%pLCFW zzL$3@`=Zs@-=z?1jggDET;hNi@vZD35*iwmkVT71=ZkGN#0~;Z%cEgDeeYGq2wSM> zf$C zDEGUGhwp~0bt@uwg7xVtOQS->8wB8mM5R0yG=`#4MbzpdYVJ}yG{pK6(!O2MRX?S4 z84|X}=LlG=@7I3xrJQD6p1xH{j2avFi#V&}dj-WlUsBe$1vnrP$48OIW%)5@CiFwu zc&J$4$qIv5w3AWUPvBthKf}eq>=AN7pM6@4$oo#`%aDc*x`(m5dD?LT2OwvEVV?dk zW4%q0%{Nx=3|sT7h+W4r{wmh`3OM_jv+*vNHvUA1BQ=`j^($e_t9?nUy^>MYf0`hgKm~Ds7&dab>$@eM83$29m7t zvN6&%aE;i`3PCk`ZQ#d_Nh4k@i!rB*u7neH$33LcQoT8%`WjJ>N`x@jdTXNpUJ`XR z6dm!yM-ewe7g!E8Bulfuh9l=LItGknsqHyD5L+*iUhuD*?iO47$sUNUH2cNgm-TYw zclDu>c;cn`-`wl&y!UbmMfN^8^-aR&7<{y*lWe}q_Ut@+^Sk=*3@sVs+oZ)07vxdj zP|ztvH#db0P9ve9f>8pPjP7 zR-F|qXRz@VUUyEiUGXVl&2k5z!#!3VxJH2CI0xT2Y=heGQ^LgyWNq$<^tM=;Q45#;1OZ z_^mbwmu6*NIDs{5wK#WPN_C?S!#DmE+n+wET~#za;u%T!x(!o3$Qk z(vamteV@X}L#HR3{0>80JG8f{He6XFvEXXuFTW7e*v_hq$k%t>C2F%%x(uK zC8sU&tJUE{TjWDoWSMq?eIDVvj=LkDZMih|#5;GghJ)V}4+_cI1|VX#_v;CcD~9@B z`7>F-KFDp$_m;MUOOa&lcVFTzsgxOk}lja zZKx^+_`H3*SMu6$Kd)48JV5u<ANWI&Vk~a0rr~KJwE4JF|G0uv-A3dyX z58AUq?YOI-akeVNZ?H7h_(c_fsPpS~OJt3ho|372jaI=gTyv? z4#p$ToIYG|-A9`-Y{gd^3wV00Szq_6#uXhBW@Ol2F!2k-wykw)oJE%u;XPcxO+JWw#=D+P?9G1vNf2+;YO-l~*sM?> zivol(iH){4bk3;GD3LnsShzvL6?k2`Fmf&UrJ5%rSTQWCM6lw~a; z>TR6Uk)lT#Vl8{3PN+y7L#)maY0Zf+M!%}w0u0Tq|%foYcjybzm}B_C~Ijv+9Kyqog@2%D(Mf--gzSHvW0lpM+)gMyKs-~ zsC$8uC^t7w%EftG)69PB9_DTpSTB(AoA{X(w&hn6GvaCYp2Brv2HVWZA9gu3<-|)i zTH)~8R7H8{`<_=1{UG$xF=iw8@2gJ_$4vc3#R4L4$$<>A&ya4-uZ57!arO7ir6^)PV zUrl`M_%Y1#ice`&_1#aGJE(oW9pdKa;r-_qR-f#=_}`J(^yD`aoD;<+o&Eh$ix$3* zwL0kiNwLWnK0ddi4;qO*b8qxqLDVhX6nm17p_EzV^|*LQ+}@-6s+&5gEIk3c== zG68L|>;&2|pja(c958n2ibeYUwe;gfuli2x9Ti-Vw$WC!1T@9}XjF+x8f;goiuuua zg-+N8UiNonSyTl3%D1aTESy}gjPx#@Dv#_j2ccsqdB*C@RryaQQ4141wMfRvbzIOw z+?Z-h%Ak@(250 z4~(}BQmy0{f?DWe^qCIa(Z_GEMY9cj)nonUJT6D2)vc>Io>cNhGneytmzSrz#|D_2 zi#fNMlDtu?>zypG+OXg_E{euhGPxS}GKr;>*|TkQC2AV$ply!ws@Pf}Mu|T$d%r~_z;YpmLJypuVF=>_5TTL%y4>N*tn|p6sNT4?*w-}{Lf{>cW&*hrXshWTgQ;3O{K=7i};;)H~x4V&jH#}s6-Z_a_YS~lOAWP zP0{@JQLTnCF*D|zQB=8vm%mrZT7xe!HL51} zS&UXL#ETyF(BHnlxz`lKylRK0BrHBsV}>R`s+)Tz!5^45gjh{IQTY8Mjj zC)Q^E8XT(~lKgn0MfW9o=w;3ZB_60h71U&(9=`5bo3>VZ;)?g-SYB)LnVX=?vWHX8 zR?hbPFj9@`--wrY`5@B(bM`Av_2@pvs3jF$5ZT4i?#<>| z9TsC79K%QvhwLPsP^3rkNNDK`XeK9W>Ked57*Wk2Ue^S1XUhZzN~`kE+aM1 zk;{2~jeB>%L#t|9vh^Sj+rc}!Giuc(9_H?KO<fU6oX z60cG-;GvSp%8My6G3nGOS5CxQ{Mpx$fw(qm*DNYNgMXvhCL!BQvlL7chqamskMy$v z2$FpQGLFg7ODI)H;h#!x7mgn66y=`C=SKZq+(oExGaz*WyJM0 z4(!l}Ma&O)YM>tFx{Nhb!BC}PC$2Fx4z?Go6R$B1 zT^#7KEi75Qr8TfT^3Z-1B1Rb`eD&p;1q3M_!Z#{E;FIZU@{spn(J_;H17;qVx>gj26B?Rs^NTFS3GkzRVqM#tb%lME z^O>0A9c8|!9HXSFD!L&K{q?Y=o}IKEw_fkR<#B`8sl|aZAnn1@F?>&VK~&X-R+HVS zFQ;q|q;0ax7OBrImRSX13*ri357^@Ac6C%^lUf@Er0Z#$e_~G}f_qckRX4kU|F_OU zH!uZ1H)*+{6hqbYE;-X+BE;X@0GUA;HFsjcg8k|TkLbPG5*(RR2eKX*HYghy3!K|CP`|kVvYAc26oWwqdTc+2CH%06;>Cr> z4R~4jQow43GOGVbA&0XD9J$>JYy%Nb#Q^KC zzlwLq1G`Vf{UWMgrMUO@Cl1IDB7s(fK~HIFaR1C78_*psyO6xBO;`XXSZKsXLd4 z6kQc>!h&fFo_qeAvjz8-j#0TYj}y+?mfk}Yd?F{QRG~%>JFX}ZSSnF*Dn=p|8kT*8 z{_g@@b+4GE<&Jh-xfDNT9A9_X4U0Zg_o%+{@$op743SFIioCrLB-Mu_gdk6mRx3-z zq)t(TrTCSmw7@~1iBta!2uz7D4vh=#A*%Nm9Xu{|GOPsPS?F-WOj2I>GzKw8&{Gou zM+i_~g7#WnOOy~PmeF$}A{Mfd3pBJ;lNMiu`Yxl}&qDT?qZmbVwI@+2todCY3XIj3 zU8#iuP^YlEn?+If7U~vY&E0J6nvRBrW!|QBwS`E%`8~31RN*`pX+nze`A(;DjyXks@@@r>w6 zQBO&;p*tHWrhzg}B2q?GI_uN~AN~nwEn=*wN#Z_ctf)2quRzh%-nSwD^{Sc^*7OYI zRzKjOGDO>rrNSVhlK{E~2)!hhw$9W)eaCZjK-zCbYN1J%=+?CW>~B_M;+uT3tdVx1 z;YLrhT05&P9Hrgbnx-eGJ^Mi`0Fbe4V7>sj={x=P<~ywqB&oDe{S-iZ`lhF|CEZ=5 zex^?S3_y1+&Tr?lCf2G$3eapcl@b8=l^I$|>JRAv)%%5t^T$voqCG|5IJ9rv{rZph zKM+*f0jm5n6WPFjfL+Qab!Zd7KmEyQ?-jJ?%E4ytkN9J1^g7)VG1fSopYR)j||l@O(Xr~^r|kI#PSiT*x9NQjx2BAjims&^Prl#M!Z07(y;bXwCrxo__LxV)m$TcRUt|!MzY^jvt><1K zgBZjQ8hu?x)1U^gR5u?^cOIdyD<}g`s_Z8i#EVl-+h^a)ZI@CRCiNE>^d`3nk=Y)g zbDWqi*0Z-47bYppCb(Kdk&uvYW6@ikDQrnq?~VhSx1qfzHK&ZXV@hPZL)-)qwH+9b8U`O z?Y2`Q4yz4w*kYE#W?i0>0A!Zrph}oAK58lL4kmbTuZHy6+VvaTtFK?kxOrW5{r?SK z4*rjnI0q<6ni6`eQN?n+2N&nbz|N7nKgji-ZN@q368j=0i7jT^xjC3#@o14lQUjqp z^nKPv+)iB^cR0RhV5=A0WJ0LtN)Cu@F}7#07pTCj&c+kLzab zuAv(45eM4C|E|veeUnaTSKmpxLSPAfh*GepTqhI&G1e?b0!=c7CW+eTxe&j{l>Qju znmHlCTfHWSeP~4C8QrRdBMt4j`dN~w&rv64H0G0>VV3;6pM2^OK(AemzOT}qQ0aCH zi84veG`VpIl9Y7pHmwz3jz^pT;9}au^P<+W8mu1 z9&II*f)sEecpt-HOhDQhU~VK#%vsd$7UKSJ;$K-49;vZ~_@wf@2B|s`Dm2Sk!`ezV zaeq(vAIZcM$Zd=?EWG zL)M>Mt?w~~GM_@tX;6ACwzfdzf}GGrqIz*LQdbF!fo*u=GflK z0hbOxVqplD$sn35a4{g}vT+~5kMY1}sK-+cCy|4UQ}-qo=Wpj?UAt7br5QKp?j|zJ z{CPOnse!{!ii5csEJo?po^UH+rXK^q?FqO^ElMiBld6lSlJUsFJ2SiXwxr@rBOcI! zvy0sYfA;Ur+GI#MNL*@WU~+Nyn~6RQd}J?f=SsuxaI*}qNrFJ<_!EdZ2Y)BlTz8?* zhhb>cXRG~p9jzbQDa3h@-%kVSt>n!SQ{^?0dOa#|PGOW1p!;g@L?6d|kq*}Ibd7-K zV*>OedBs}=C$&T$N;U^^pN61YCIy&;tLW}2JYJ1yd6TmVh5#7>65-*xcL|$^^X7@0yKxJV;g@NM*wSY78tj4 z6DJ{CriCVs1}k)@^yuouIY`p0sf?@sjR>+~0I=Ebm%bLGbz}tl=jaat(|GrrM=PS@ zb;&K>^-ulGpE0wGWNEtP`IJy<==qYKz7`%xx|MqsTrc#7ORs`lsqWe{ zKI81x!c0B87X81%|LbTr+X*J_++465 zxSjj}r-}X@GmRgjY}q=OdsFvRngT~t^ApK<&6=`RWPF06%=XFNJ?eZ8Z@1N2ycKXq z0~njwL*NJ>*aV7~{=IvcLG)ATz@oI*UA#H>E?4)yN7qRgAkLLLU_IYYnlb* z53@v{Vl!&~uLQ-0?tgo4QMc<@^{YxBc2ki5^~F7Y+c`Xz~Mu z#O_rt?ZU8R>X7MSu1F_LyKxz*$aUYPwaGEAX%%koIrmjj&Vc9aR^~x zm7{OtshcykLb|;oL`^?^UVlzS7?4pw3UjvNbc+Pj&cKY=oqqNs9;v`+{LSmHZxb^N z{J4o?;1*u9o?BYu2+1%`O_o^`9}CeF0P+wA^N3MbXKkd>nBInrL0wER!iEVu%yzZ0 zWBT(H!)>$HMnlfE%^esGx7C5N4@M^bAyv2)pv&*H>**QCaH2Vg6Bjala~Ka&7}!a5 z_qG9oCy+?&CDNRao--uw$y>h={T^&8EBnb`I{e?o(hM}(A*Q-)Z`gZlW07HV+dWj6pUegfsrvx~d=1DbC$?d&>SaSxJ zO5XoLBkU`}g#tyN_h>Id_{Jtb?QaKBu7B_Pq*an z228y>I|uLh14R!9LUuadr30~~zS}qd5C@m;9J+ZsL-#VxvB=}>ar@aF+qj@~*dO>N zj>E{E?eP1RslV?-cXjJ;pfl1?%TDfPEMSyehYf&)YVzx5R)?gRRDQ>GZ^ET=Lpx#( zjj5e9MOPU4(rJW|EJr_o_f`NUhUP1Vf8y>7(8V0u8=%CjnjUp@KQ>{CuX-0oX70Yz z|J66u(qhlEUzlHC?}l}SAJa3Ksn~^4a8`X;WJbzJl!GTk2;c4S6*uY9!wW8uy_oy! z4=+vn_CsYB^xdtB7gx@eoTat!A-{ZCMsnZVi;D@WhSKXGz-9svtL~*lY3oGpd}PIfb1~ zIN0+PNo?IXx98r-o!02C*;eMY3Gjugzam;aOOE;PI&mqwo8Diz`s0Zf-t&}x`#mV{ zrrCE0_aCvj5ACn^Y30$Tfd!=dcuaKg;TJJa?mJv6zI)vEMMnkN{lyS=YW0PVkwWjT zK@r-ydsHNM((Q@}c08RreXrznB@;=l;k6@d&J0qv+8U3Pqn#4z3;t&N4=W0+60=8w z%_~?D_w3Gba*=kIIWhO`>SWB|>WaaPVB3}d-W&AO>G<5?w%~$#^5Loaonkc2Mv5fbKh1TtEh` zyflhLrZM)6AnhfdhK-J;tviaGUK}nDw!4-ck?h!!U;oPS-kbUjPBoE68=P9@DmB7H zu2E6z`bxisK85Us-R#|9r!ssoIY#N2?W4d2>e z79nk4(v8ws`;aEPm1O}YiQmHEj29_57H##vTz#iQcsC=j6PP48HQx+*ec>P^N3T-k%xQ8H{vus z%ft&ZfA?vFqJ4EbYcIF3&2KABLow&qbWe%PQVlHP*wd&|jcDg>f$z$Y4cOD#04%`; z;j%3_p<-5Cb&*l%mY#$itG;_Ol#52*uGbsOfZ21@deNNTE!|_&B2O2T#m`pU`&^?o zj!ThaPd$DusC?ckWB`!I^_J#A3kGU{XwRK!Y{>*O*> zniUY``KsRV^X$u1_bI0dcChT$0>N5EHx3hvB5PwX=jN!W_!!^HcIrc$&v@hY`zeuU zSx+-lghj)BJBjV^eoH@cQP^P+K|Oe9ADxV3b}|jGkM-NShog+1G4ab$9KdZ)xp0jt znK0;gmy)C|*fqyAJTv6y6#NAdUYTTYjpOg^Glkfds$-m!9%S*0@98$D<_Q${EWQt& zB2;V0Ni-dQu|Y|OEZKBQTZdH)bl(N;Nhuu5Ez6Ip)SfQk*OI~J8ezqyjTG$;QT_%S z17xvrkG`dKA6X_W<_y*wu5uC6mINgU_-TD10Biv!>w9-7&^DWNycgK*N17Rahv-4H zU-P;N9SnPFy6;N1qRdzKwSLEtKlv~PvtIR@paKL>HM?{Q@!RzKD|BdJonoIQCSeCc zN?8P6ANppw;xa{eci9-~Gmv!f9Acy01vO=ePhrL^^ZYGuy|z6tnUvIZN^&xh8yF)4 znPiUkS-I{tK4!cf?{6@gK}NKowk!bquXT(e;*`l*Pna8v zJg7H*wON`EuPAkz&P4-hfp`bevVW)F`whW<6gOZO;hZ0^G;|=zUE4mnGWcN8 zIL85BNhO&2l94Pi2q&3YunYG$4(`0`^K_xYhKWP5H9ci5U`a=CLlE`}rLQd&QoY*c z_QshXs7?c{Q^2N22`x;*_n&@_L!en+UDvaCqAExu(zg z)hSf8D#&Dpde`Ija7jwXrSnHLI>v(>sqnZ$o=RFl`g|UGyo#}HFF$uo3`D2;`5I3d z<$1(i8cHvIy`smXF95HNPpieGTmVG-6U=91H@S%6Q4F|Hm#%EX z4~#a=lZtLU&CxFykS=%Cw!_v8SQ4yj5f&GOX!q)eDMDF4369lGnV(v*fDtDaM!>V= zqEygNU;CX?#is^S^z(k}&;Vcad^oT1hv86*zOv}y@a^C$E#mh+&p@Rv6vHir#Riq3lc7{}EK0v&u@P=Rgf z^bf!N?CNydj2X7(S7Z9KKO87YxY=Xgs+4-$c)fG}EHZ4g9sB*+#0*^C!$8c53;%N{ z2+tBqPiqub%c&j(tlRaft@9!qwz?pEC`sM<17vMKTo5#-$S;>mY|Vi|VIJ)`r6654 z*NwpVf+DY&My-8##oJ}OVR0!IElWwaxsiPN{ki#ytFE05i7B9mr{Vj~-^bYp*4X3g z2$1S+3=_T44Vk0snlsGXRgUvfPJNq=?2hPbt+Ef8^a^$~mRzDQTP7Lb03{6$8ta=7$k?B3hn1OMq3w-pX zM(>>r)9yC2FerLoVDQST09*gsr!v(%(0r|BJ5wpFsbO9Q{b3n@rA>S(%Jo^?P2dq0d_H3e&;p&)9QR0Vyi((~Qyn#x8bt|IJS1^X_8ZC8aIqUm?Z z`G+;IBOF{Y2CJs&9%duIli)sWZeGK4()eZJ<8W^uG$w*UtGL2k#dpS+y#yPty^J zNhqZNd5PqFN#oqjjazov|8{uCDW)vWjlA6E6E|aaTChLB1~I$bBm4Vne5VY3cp7<@ zPj6A0kIy6LWp?mU%w)B@%T(fYJhKEdYJH zG}JqcX}5|pE-$@?EB`^#pXckpAnDVR^jox)T4(%R6MYVIICln8O82#ZT6@nR%Tc)ta%8Be+Up|Tq#GCY9OaO0)0kMnWGN&&yWHmM2zOx6)P)1 zfOYFU{L?U^9-Jf^H5W%Z>E0{7+jjht%r+|%)%F8*nvXgON@)@zp61z~QKmF~fky`f zE84x^3K2)~R;o%Do=Sl8qDbir%6T80?6t*uKE9__I^zu;Brit4nr`(&`c zBs2JaJK|dpg3*G|bRhqdr*?cr3^yW->`+m3KxG!lR0(1Mbau+&IUM8(8gfqt?0^h< zfWz~71KXp59+JcN($4OY6JmtW9Rld){9O@~#N57(J;X8Qu0D@!y-|!mIsKIU43C%GmILt@IY<6?b zLph$r9#cUt^HIM3k%wsTTNR;vXV8LhG~FvN=m z_!}M#`6@&$38SM}^z9)KIgA@cCEMqsJ!KiJQRiei*EVvYmy)u*~v!D$*f)qjjOxO&%2qED9)x4HV3iJ zTf}IfJ7TZ_k;O(fkPuB8NHoRTVmoaAJiKPUDvM<8)5r<0znmqgY9t}{q`~AI*sZAz z8JQ_3mYIkb(X3wA?Sro&(xC07G{n)YkVc?=lgc-3yIJ<4Ppbg6zHz?@iu_xH_zXyT zEKI&73TdY~mPD3Xxf>^KKShf-e(`+6y-4HhG~?wp+ixTrzjZm%kZykkd}Oj*3pt)_ zs6^B)Y*%3IpHfm=%S-{5sfb8REt7n8UbMwJL<&p2Fpq#{B=6S%oQHtFM*wv+xTO-5 zUhBG@7I~BwN&5;-FoHVLcuw;!NgPjopes)v78{!?|$u6S;h2 zs@_|g*0O^iR=K3jL!D0QPRkS(E-MzRf3 zRC4@4%xdw_D!*Xd5e{NOpnHU^x7%|?6I(Cke zX=R=FJ|{xtD-%ExhYRt`ygZLQ+rQyf%7IR`XD1Nx zeK@#AdFob%`P=6~vo70D)F-D%%|)@>8Br!m4!W<$G2-pO`12seRQrJ21_c0d*|e>$ z{mA3M4SEBJd8?3rMRbK6bykVwHPWA0A$64?>JX612U!b1P5`iJ1z|fMx{b}w^*S3j z+`V+KVP!7HcN;WK<6@D_F0EY?QMP8`p3DB_k1&o_ZV|v&zR`VxLH|S3JNI$t6h_j6 zffuS^cbeg9jxNHwq`WM8_lKR;D0D6dwO#;d9ltrVb_tr1U{h|xT!Sa;HAn#sg)+!-2`4n?JZM=sZru@uTY9W_S3 zMSgFdTVrQ{b2wLqzQ#u_TTRm|ddvf&s5NF4T?XaXW}d$rxk#8Uzs5$+&Bq%SSXA2S zU*ppee%9S`#53)E?l9`vgavHzfXOJuZjSy#z2O>-a(>2ZK=iz6K5;1eF6b{r@50us zA>#>ilA?HHoE*ifC)FNzM0UPon~lzGn}r0f@FW+toeCmdl5#N(CmH6OS;G zt}sR%=(V9)&)z8I=%df8w!YhNyQeg32(|f!-HG%G#+;D997vdYNBH`NJ5u^7s||lO`9A^?2Mkw}bL@-CW$apI>1{q)qLP%Mjmylls?V#iCCT8|tg` zoL;fr?_}h25=V;vaq4-isn6kGi;TC=^GX6%<%M3N;cSZ4v|Qylr4t?B5;ewp*e*fV zIz(LJB)_)$xcRTFE>3^_8zlZ>R9Dcw3H;;DcJZG))%TyhU$ZM#xlJuv?GQzMcUrk^ z-#f&Kzv2~p#%uOJaQBL>-!q=O19hZ}*Pirp;+dYISLm0$*DAK5;xC6Lytg*tvCGa? z@W#}TlM8#uN<+4`Go^NSd%?LAaqYr$r(M zoO(luGn&+^LU!JGL&$1B4>9}?+Huz{p?xZ)JjlGvuY5Vw>XaBCVKvo9^h#U-x#wA* zuz(7Ko@yjS(AB;bQJzy`eCTaWW7Y1P>=A+oL;Z*lc37wBWUS}#lWSwa*LD+kv4p#c zVRb0~%24;Uqt~K>-}hBSCytA|L$`lnp2&_n2qCaDU(68yOe*}!*{Xeb;zsi?lF^Nx zKkJ`HS*2H=eqr^+4;yWMUOQaro@R~bFz)%`*&j0fiL4w%e+BIz8dA=77mBgr=0IY3 zAj__w5bSQ}R>9_!Z6kVmwMAjs*v+FAfj9p*Q?*D__u)3;|K~@H3l4u%5#RA`rXtK+ z(FfFH$4r6o?qeD-6z?+ymN>}eLMS#{wG1?1b1X4_B9|F`I$&Ho`5Nq=TwfsBbaqs!}$?wiqFWG&t$hdPV zpqh^7zAo{R3h&eoh$t7RNA9nyK&OVLr8=tvTcS%0KXt_0qo^gsC zm|0UPl+!;oHYXX^J zCnA*BtckHB`}%*y-=FEk#nPvizns`eM;x0RR_Q!G5#cQiY&kSz;E#>XxMzJPTPcLy zOdRkvFoC4OBaJ~3&U7JaGf&5`-<5NeDn=t?by&~fB=A&t;m3y*zvnIC?rt$wn@4+y zotlk~QDx{ryLqBzyL5F*I8uM41=lhlb&4CN#*I<5UUaTubq6B2`(E@@ZBV6jL5AL( z4r0~x*|;kT!nJs3N-)muj2OAAxE`efG_2H0)( z|2To|isn2Fu&*-ut&aB=Y-`&P9W*C40})XhH9fckgAyd0zY$L_2EXuT=->lHdD|+$ z&++O$ya1qGPX$t9guY`;c)z?=KNXYvBJnD$s|yp z6i@ilRn5SXy;j!s!?M(ffbx!Xj}p`RqPVB^++yyc0)3bbrZFS*T7sNx-p~Tw>bP(p zOp=K(+stGFBi+@QL^(`9z50|B5BABMJ9nEWClmE!sF%cKHf*lgSI<<(L$dNq%u;+w}0A#6{*2 z5a>ANk~$_n997>n7wX@0e@E>#+<5+o~*;N+LJx5!8KA>Yh&t~u0+{l_6 zpC+E++;!{%>ce(lHeA2ZC-P2T_hNT15!nGf{HP3t9uX0*<@;wux)nWN>b!}%Ug|K= ziH_U-(O6FPw|mr95~RB+sB%eUCGW6vAQ#~FYy9jka&1Bp4 z=^eYyO=<(yR8(#UAkx~-#W`pYRDGG`Po1Y@;5QM!?Uab5RWdFiTDg-YW-INvq)`4FKk?NwSV%7zOM! zpvxbTzKd5FXB+EQ8c{6kaCgWwt%m@nV9bQ)Tyku%%c)i`c1YQM0n&i{Ge>$&om^_4 z)JaQsy|&!xN)Gg(8|RZ+2*@2>M3_3YIMFa`2wA z__Ww--YtVz_dP9$?`GmBcO85X|HWiX4)Y#9)ekzPKCWH8?{21?aNrH2R=i2?RsbT( z_%8CP3D9-u%{`ezmW45k+K#6njB{$b-ofP#U-JtB*llwOIwAAqnqRLq4nyh#aha(; zt!#E)W!~Snn&F0*mHo}BK)W}cxy1eR%=drK!&%u}Y*x%ILRfKOc(%H5t3Z7AgB+pd z2J-8M{B5$zkZXSaz}JxGjn{Bt2Rw$~wz3U*$Nh_TPX@TRdQ>`0w3WP>KipTFE$+;z zvSqj<@%@fl?Zy~KdT&#RR!pu}&bCzEV1sdvXTNpWm;TJ*u>VTy=g(B*oj&J^y^Y(Z zrUo(LU-LHslIhpjnDpkG8mLhpr2=Q#7?c$XCHQ+W5NOg8v zM9h*uSAY-x6bHtLpsHrxPLbYnHUOLrfba!Sz&Y(C3QhnB$E}770Kp9%*zD@CWB`Z{ z04Nz**Uxpqp+&Q~*9z^_{HtfkYdRQWU(E_GAIjx`xm2(}Uq`~q<@>gjv5={JXfj=z zMk+|J?ke_Nv+IU$$pm%xM{13cv{s|z%IZCHO1dW>T5NxH4-1+?6Yr!$3WkuYR64QM zNbR?*S4m^(1!_2yA;?RkfiFftnS8~8-xpFy;7k=%jgh9zL;0ku88}$REc}ox*GFUb zfFzO9bS@Z){WLlm*Bp2l3$z0Aaa?*3fNkl8#&BS45+qz8`LYiZ$JS%Z>ib4;FKO4~ zhxC@mt04ylpuy}SK%2s>Kw?n=38S5PpD%g4K=)jG;+vh-qVeB>4tmW+o(X)|4*bcv zYiB6xYcFdc2WKVGF&M$1`QJ&fMI%h}rlJJZ$@mzVwnhBl5ZU`YMvA{2N6*WaOZQQ+ zaoy72O6j3o`M_7NANhH_R7p{teu7{nr~tMr6~a=9tbsZ~q;=Zx&}&vaQUP8`Zwp!y zt*jP%j=64>fe+-b+Xgb)5oc+~gaj#bt-Rb~6u1bk4{jKZNr%KxwWA3>tTM4}{0a8F z2v!G)xQ7cG65B8<(sn{jK7g}TS~L(1D!3U%3yR0A+!fRCx z7*QVMamz-1g1rRdtgwC{KEcW+vIOFgd@%1G_zdbKJ6qz|P_a;kW{r!Z$J!R!PS}Jy zt_3~_C~~wti2o;F5u0jgOA)P{7df}+oE=@a&H!lW(L)#mAi^c#MiMWjNN?5#LK46* zLWEqv%P=D0S8w{*9KNb2xWQ;YQvk*Q4YE(WlfsZ$ z=ouIHWO$c&tY%HVU21AgNz(cB>Xs^*`@S1p7utIFc@;um`X2Qt+?x;OlE7=Sbskfx zm*CJs*Q+HISUTIdf(lN}hOF{~`*`T^G*D*hniL~x_;NSIPwr7k5`Ry-dO1Sso0=CP z_sIVMB^~k4ym5X9RZ=z$cUMXYMyF12b^KBtqdnYSMaVzMVVT{M9b?eQS#aDe_!ZSF znzS*{t|B)Cz5Wzz_mJM-J&;NH#ZZpRb*qipZIA=`l79hKKUALVS;KK9WNco)c^}&S zfBMR~>jf!3Yz@ZW)gpVv-?IZwaU{pS7gBy7ME|BFJ2mg)&kfmt)|YWI9u+?kjJ_I= zF+&1+M5`EQr*6)o1uN#no-*r|0Sn0E00Ak}8ytg$O3h8>b zcpY4)VjPjCqPj0GugJgOa)J9ejRx^!iJ2_%%0_Ba8v6ABUaZov>F1O|6z5P=X7e6!|E@?C$tVHk7-B_K)HuoNLb%2t7#r$I$!8vM6;&?kaB>*d~;?J0FCLv!C!=;R)_1H!+BKm zL;`3x!vrh<(||;(j$gj?t_rqK08Pe0)dK0BYG2sqSNf-mkHJqjaKhH$p?-8Rmm?M@ z73}+%kd{r|!Ro4zLiybZON_7x&fYd zK%U23s;(u8v#Y@=Y-!e$@JK$CKMyUF7V|Z^?o0XWl+x^ht2fg}B8-HGb;R*=I_ofy z=$3UfE-c0f9oYhl*G^#ZCV8hovUvOdJoDy(&>##=O zV?6LxQeN<8SG=$^MzGcWVFhtqZ;1~JdM_4`oxjQ@*#lFQ^Jf$264M&VK`nM}RQjxY z;7iyU%_V8{!caxeTJHL`IL3STUl){(G|3 z<5@pNu`FCH)h%Y45atmTt5ij?%AJDN?dlA6W=Iq!5~n55OPSbN19Bb`d5}O9w#b^T z-7gY7wFHgSZl|e&v`c)ka#691R&fOqvcZnCq8Luu-;R0T{j@)>T%tzodjm$y2Iz1E zP%6MCRfK~J3|R3nIRJ|hVwXB3YlR)?7oR2p!Ir?ICo&EA0sviL8!0|R%PU%jo%kiy zrhjLKP+S_6k9nWL&eJy=iujM}O9gY26OOw}Pq-s!e&8*ZVn6n`=G&$lFS3i2a9=uf zM?Q2ftLFJ9sk__vWjagDp-J~FN%?Z!su7R&wQF)(pl8e8XB|qd(Sva{xo#x!b(&el z;MFqMT)%Ay#jr|DJ$a_u}zPO_o5u33DrI)*?2w?LVVWi1+5V zOz!-rlw8%(Uxt)=mqQZM);3p2E+4u>B=7#&?R4^rbc^cb+JC)y1Cp57Q z5?W&fdC@X?!u*Y?2uN1-*9A&t6x^WnA3e%zaiB6cP%aLv1_!s~Mg+XEE=v!wkvPwa zodv5_wk+pPzBm-D%xy@C302Mfw2^o(iT;hb9d+JPA;I~#y>5AeTZ6NR$kskuQQmXM zNYgQP4)dUScWuA$ zE$8N2VeTS_mLo6p=|S?rho()Re^;?^Rvur$;nTP4Cq-}n%lP(vqh|6D_nUZf!(KZE zW%p0e?b3Y$80@A8}=e&BMKK@0U5`Rhn*Uesd>iAJGwG;j;&C_bZuI?UMRP zjBxOy%4#o&;bx$skg>QfjuYAU?O<%j-q)4WmU&kuub#dmUYkq&l-j;W^Lg6D_ECJh zuUceaOdQ0>*X{0^JlOf&h3wZA@NK;JT$sf={+&6p-$d!D$Q$kl>o?oKaF6@awY%tU z0nt7FN?PZxZBB;vJmrOJAN>CExSsR5JN9Fv^<%q@Z5-|cD6-Nq{2F~*Mczl^jl@g5 zYd5|HKl>|Qx#wEak3zH7M8#M|&)d&WyJDLbu4Zgld1Z`(IxsizI+zN1B|>}687;$# zgfsUjMGjX7MuHE^7c+`%t`3S&&SMF-TOG#6RwX&!z9lMJxiFfs(Y?E4p^(=!Q2x-t zxEx#R-IbDz@)}RUJt*!sZ(&++-27xUAwGIW$F$l8d(R-j4=`+I!3J9~2(uLSXKs-=`( zK~o5A-6yhiB1t)XH#Wsq2kka8)>Ii0#CQ`zF6&=zR4%xzRl|Jf^3LoJ4?aU0i7Vn8d% zhrk(+Nv);ZrFI7{vY(|Mty7>K{70snB%vd|Np@HL*3PV+9dkw5wK={E;+?q(dhplf zn1*|+U%>5yR?i7W3T+jvq=z;N&XeQozIL>uom*#Ng-*uM75_Q3jcyN|Q}#Pd3(;Yb zy$tQ@skhZkVT|%Mu^S_CKxUqa?*I|IFft*Zi{fcSX3who0egjt=&@dlAKHZMbNX?f z+kja&A4c7yUHlfO8g{F~bE_-S7@x=Dv4(!{?Ng-9oB{3T(TM&P7v`Kc7IQ;#-4k3) zVuq4J8fX1Zeegr9I!F#VO_kWbb#noyFZUsT!Ol*SaDV36C8!-U#U0+w} z#BM7H!C2w?&B zWWw(zY1lca5}8%FyM<~^+alxR5B&nHG+7@nR zl*iMbR1ZYLC_FAAo-M-4>EPp49EkZ7A@+k3vS#D-ihuH<_a`yXpdnyQ9j4z5qyoU) zNPvR?h>hoz0PnZF5!^K(rfSlQ4$UyJ4O(cONe2MDH9#ls1k^(SLh=A0Qw9*sqKjJg z7h9Up)d;*?(Aw4qq_AI@6{Vo7uL zuZ2gLEvKYzGd8-uY-@~|y^C{uXY@5|-;0@LceV_u^^Y!DVo)(cMDP#*OVyzBSl}(y zBK@O_y%y#F`^#s*MQ^W%CFU7MY&68yrmkGf(OHDWclmYqx+GjT%ilTGpeWS4mpHGI zu5Q&9+^1NYH_`pQA2Cqz*&^JPe93p8DM{FD53W@cV5d~KL@Ii4r#U0Um6aul+=gWa zsxKgmvbqSf?T5F{#c7A_z@ms5sTueJrErd{AD7LMw=SiGiCp(D15o>4zJ*bh^@MU|r6Ki;zr_ZM-jFNE$_-c)MBR5+mJKfL zUkqReQ*Kup_fuJ45Sz0=c)c9b_dVp7pE5=FEKyGl=Y}NCJY>?Bi#mpT?AkmtDf(He z-dnTsf}moN9_n?u&Bca-uozeU(XuNxOK&93AxS7eQG-=AB7|}hgk9b*7aQDGh0r2} zC%3RguD9m%15+-_6WPpW`M1lP8jP|3a|QLQ1U4)`5LTvCqf@yYcq2upP;ckIFsZ)r}R0hI_Bf?d3U>it1*h(dY zQQTmXW9et|N>kuA&UJ61YlREUCIa{Z!88`Ajcx#5nIPRWOMXsnQ;36$)8@Z>?@(`Y zAOs8@2!sn}E0Lo`d5hpYh<*6FX#mWMo=HRaTHt}Y?PX}A28OTEpx~*`A4Y4z*afmp z^w&$FP*>q>ZtTk;=rFZDBk@bHt4fsbOJW-3FVpW#_81($;obi}>nGF)|3y$)Cx!s{ z0ot6uG$Y^AR09Z%R;91Kl@hv=1N#_{P{SC0x$n2^4=}m-rEr&~{`QM(adB!JdjBF! z+%*REjAAs#$bAX-9^P54;g%e1Z1{JQ1F^mN9BF1SKx~6U9I3;F*)7;LSB5@2PPi0B zbu<{Rw;HzMZ>Wz>=?&^wtyef5&K-E=Yd*)n`}tj&PmIzMKYv~$NnzAaiDvz-D*e4a z3;kmULHK+7A$Et=C|+MLF`|%i;04e+uDii-DOtaAY*=rX%CdBRiR+b;ygshwB(Z%- z?7$wj+LZ!h-AwaeT&H*)x@2TEdvEznpE1p#+#m*NB^J@vq6}MV%<>VobduqxJ5;_% zt=ME2!ZZiobJBWY`5^f<)DLmN@q@-%fa;9t#Oa&&40*r{kwUwtKwC$k+eXlLIn$gk zvZ66fopftVM=x*eaS0N6UF$iug<)vYj(97GSPk(GkG{8}VOi0%`b*MPPodSkT9)5Rhi zc@OoX!ZuR?$1DQ^)&Xn>n06j2C|?&z66~_OtM>-Lt26ax`eE65b)!ksbdy`iLfbZL#QQAqGCUH5oga zKsI8MNjDh=u{nS}E>FYETXd9#RA6I9o?xd5Lxq5(UD}XyM zj}EyToqb;g^5QYeT1xpV@=8XbySMosmsmLC5hdd+@4nn11dG~ON`+fURZME1NYO2_ zoy|XS3gbSYvsfIMCx5e7nC~zM`?D6YxcRYr=DI9u_IpeT2l3iqv67v?Vku9BL zzq-bTpO=GS3+fE*cm}qZA;pX^BU6;mFFP3mvH>U6;LB|WNoUqEPsJG6bu%4`nc9}{ zb7Shu5(6i1hV|yV4cSbebv?v@21diHj1}LAs*F4Iwb_5dGYU5@Aoe+iO0I`v3?7U!Aej|BFXyL{%I>b{$9lXtt_XZ=0{f4Aar?ADcmutv9RYHpt!7gwO(j^8h-k z$mcw9Xj+|{D*!P&8de2x`fP%f0m2LDh~^KR6v-5_eOtO4skC8gqP9R(0w14RZ;w)|v10pU6G+ z;Hn|UCuNQ2ab)O6xo-)kTwAR+{Z?Puqx_;=leC^$f^b!i6v{V{>uLdXlht&3>wWh>+|W zzF&EdcTJQ-uZ%7N_$)3!(%$F9-q=LAwAi!0&|Yvr@t}zN^I6n}1KRms;8|V^r#bge zRHb#Bj=6wfebI5o#_`Xlf3%I8yYSqxxzhB7$NBOi_2B5#@i3n{Zft1#suH9CkIVx` z(^1GY?)p_RojHn%pqkO3piQf}>tPMbasJVcZ$X=$k$?rSUn{-HWaSzOQaLrE*lY9A zXzhtLtu&09DXW313XkRyue@#U%8|Q$Ru<3sA*Myk=W=T7;oW!mfiYcP5|3Ro+dRhB zdv}9@Dmk^%+dFT)-{|M`xJ()AF1Ui*L|!A{k-(mH*MB|0LEvs#@#To z#(?~vO@4@I<3?t*m%gG_S6u~=Ld)OOq>To8$G^G%2qg5fm0Qgq$M2@)NAB`VjE*}I zks%i9!@GFLTR?Bqh95tkb&-M|ikpX?;Oy6=n|IwYW4h~ z&BkAFI--wIrnMjAF$~hfn{*CrUf)Y!a;TXW!bzZnB4~PG!Hs9XfOFu^GR(AY(KE($ zJyXa#I>oau;4x4+_FUa{6<3^$NO>%?Z#(Ju@eclgd@1HJ|2xn(jR&wEAz0kq2;fE$ z3RKG32E^BMwLT>y?t1WU^T@V~Am>}YyF?yxK`u)Yl+U$?DCBGjmaSQDes2Bc_5u|O z019G6*V~PorX#rkvIdB%0Z{nRQ1X-L^&jJ6Zv;v0xTCp6o|)LH*9VY6Lxn2@W}7{lhV(LcjD(L_IBQ0i zRr4FlM;ePqkN)TtJXx`H^X_QiByyGU`eyi}m$t&wqBx}BH$*=^uwbNia@YTwyQpul zzM%t6<-*WrWUz|wX~Fb4{zlu%ckjhgwq}>MFk|adL#L5WsXyYjuGl+&ufjThXZL{T z6hE{C+cZ?+l|*hZRxSR64DBwi_(IyXIUTwDgRYF{$%fxapxd71FP6>!#XYM*d*f6_L8r{=m)m`{k>5%ZVe= zZ6tQD%MFjD^#Mw-G)7*0NB+HZgiEZ39<{=a-2EL_~4@?t^Se0{2&vqKW!YK-SFi7>N^=#x5o?ztPCD8>|*Fov*eGck|zgav%J ztMBJy<-XH$2b2^c4FmLu5E{2=6<6H|x?SM2NkcBA8h0zKyJ+-lun}hyUHZ=Y<}jxG z%3*d9KV-xG{|%IRsNkO~fE`F+!b5vq6NoF|#x2*11EgqSF3#{Kj_@$j{^5m%vZoc# zji^KyBwS!076jOrV0N{D!W_BR0cN0T)VEpg*(qnw*LdhtCa&=ngw((y{`PUcrx`b@ z!+Yu%J<5t}x~=LU@=}nw2LtE1d;P+8>kw1bU!32($2yEnJ-_{3T@+MvEr0#8Oe`Up zWcEwbzCA_WjtuP;`T-~9$;g9Lb1`H^$i(P@1LUutjJVHL^(eM}{KsKu5vf@;O1Q)K z$&2*OyWva1`F4*ij*SMWihU-p%{aX@ule>GfoceJHOgF)2&gpTN_I_MU) zTiBcwEq{sM+#Rb7ycztpPs=I4!!FyOPJNj1jy)E%_x<~Did9Ic)0%x#xP*HfuLOx3 z75HQ8eF>LDy4lLx48Id>-oS@S*mc>uzVbPFWp%4v*nmMJdl6aXl|JI!qaM57=oLLU zre5|2ZVhhH7@}@{Z%#e)2B)k;U(_A+4*PLz*Fe_>*N{xX#?$8&cNs3xZ^z!uLYD2p zyI+l6o^`2d3~)_oTwyQlx}qF0|NgC~k{7TXiOO6bQ$%Igh$JIA zbKEnmVE}#l!LqDN2C-L*&$Hh2#qD?AHpeo$OX(Y)3`3sGwpm@3!$v`W7gfkP63$59 zC+pP3iqwjFw-fVH8HOYT%o2%&iy{OD@sQZ)0u_SH;O-iu+8U$^s*pRQpaw~$2#6`* zDkx*vlZ3TZxT*=!l6hWRIhIB-V>*M1+r*9W=p%XZei!TYc#kmawf1Q<3VIl@8{&U5 zp0wmH0aZ#-W7LnwcBo|eZ9A^vF!ozJ-dsPn^Z4=?2uz-f3`GWyY-#f2{xjE9B)Mb~^%f=wsXaZswmzPx62gw(Agj_GTmU1( za3r~DLsYG+U5U}@WouDi?WyMNl0{T<#`(+c-@iU_vhwPGir<8{rsXW&Ye~u31zDvT z`jgjX7S(QU&3sR-v`Uxo_8j{pxOmDcc4t3imGHFo@YT}EaktjE1=&t|S;2iqOM?Dj zI6b|%$F+4^?+m&82PO+O0W%(?s3khazO4bRJAoR7}U($D`*SYvd= z8%3kyy{g04cc9>Q&G~!#T`1Ak)niyekx^D{yZ^Je+G9W^{;1O!X3=)d0zPg$e`4g& z@Su@jduf)PQ+5Bmezw}X`7^K9uFoupdp5Je;gGoVFE08jPUhD7ETn<-koU?kZ+7>V zGsIAP(;d|E^QMk^_et*PffbV?HPKsmxyvCv$mS<{>}>inDViqY0>|}I5!=mYnxfCG^$(x}xVJwuT|d=PCK{b`?MOTp1Ko;Dj3^ z!&LQ>HigaC5rxyvj8ba?CD15Gvsi#-85hX{WCZ(NVJ_^B1In^7%5fmT`fh%dsn>Kax4#kfT}>nB0MAwYaUS(@y^d<>bYQ(Tcvzv&wXM8= zP}02K{sp#r)pe&s1te#JO(B4@m4H|kfOr@r-dLa`Y|BDYdvk8MN6X9~->+2-UuWm+ zg3oM4dOg)tdG^NRlg#wn(>39qnnhgZPzP;*b1!hVp*~Gz?#k)DhdnQ@PZ>hmos(|V z7O6cly5)?)i7LN){dFlx)iJ(%V5TnB0NQ;zusQSs9WLv3wtKG}*M%ZD8I$sy$7zt~Y{S8Q z*(VuiG&CRY(u3P3*3&Q6yeeP2qB_{x#jCT>bbZ5cwjFBmK1he~KcXHd)` z7EUARAO>&4x@IR<`v^NIgd<_@?}c@-Lmk_Q9<^`?BJi*#j1jPx$c^fx4wgBICq@M*)1P+D_`)2_Br%Q8~6d`bt0d=r+_tb`0?9xhC2 zk)U<95lCjpi!Z>O?^qSK-70jZR`)tzrEs4910OLovP9gWQgqV9xbKdHSuqhfJQ)Go zm}g@7ja0iS?Vd0e-{G_-!>_k=&!#l~!k`KjXMc(A{}&kZ&P3;soQTZ@TKugWAURA2 zXA4GmjS11qwTCIKtniU6r}|WB4)ug}IKWbweRF_iEC4KXkV->JX8R1*l>shfe4T(} zXDEOkUTEPpTyehIG*pRv77uScew%Ru8}6-5sXLqeVqHg?yTSQBys-eJ8H@_=&<0$* zcA3==D&+q+(9`DU5eOCl)J({1+eKl45zd$hl+60M^@RHb2@&+Mi>{c738SDJmW-Gc zp0(c=u$DQ-kj9s{0FW~)%zEjsGM6>BJ+TwY1pg%gamx@Q6lmMIY$FS_S#7kLpYL?9 z$v^nw*ORozDzur-Mobn~(53}xuG{15k_{mJk>UC`-j}ie=HxPue;#9J+`pS*cx(S< zK5~tI7ctf4sqLRNkv|$_j}yj5ZU6Bc_lbGgUnCi2Tu2-DQmp==q!6ZA)P*Da?;k2P zc2LdQa9}SP-k9{(uy`}V!yv2w>O(Q(gYrh-3*8}d4l?mc!E3MHcn|5)gEb#G$zBOp z>WZeE?Ou0>`$Ws&hSs;X7iv}o`qH=U9V>qR`?ZK3gSdX%bPm(h7G|GtG%?{o4&`JJ zvSaD*pFU2>Xm0H&rbboU{7^=;=$VN4rQ%Xbc^M^cFZQvu6?~Bk zzRIEto(48BqqrT^qh$*BVQK9~#Rzrn6O^0(Ol54rNadc-BflIpwRjys!x5yjgtf($ zmt4@4f<^wWzkfts{Q_M2>G{oNMbgT4iPQcx;|N_6Qt`_z?utMG|1U-AG_y#2al<^w zs1u=1h{U+;bTe}Mzu!B5fjEKP^3(+>=3T(5e(B$MagFNhkls?tmen2~TZIa*GBM=E zmEqDfAhFDep2Ioxc3h7N>F#~CM-v{F&8rQ*GW#8|dncU41;{nV>av1x4=%4RMd?ga z3iG|qxWdDwl3d1NZJk#T4?evWJlDv^vq_3eMdr(jvfsyUhLA0VSDe@7A|Y2_9{yv_ z&-vyK2P-#Me!AjkdZYOduav4W01NTu^^GjnM8sUpjsrI&z2R9Xu2e##6}>YIH>|nX=3WZVE=5c}Dq-ryOm7QxfG%Mzj~XXaW$+Zt zI=dL1;~!$XL4=vJXg9A;I&qS=z*6)&X@hak`UusWvVB%<=5AS@M=$|Avo4$&Az&^MtcLWqPkC#1h#56L_fo{9jhhn+5X@Wp z7EUFj*XgKmxm|N5tzd>KQ0J8DPUsO1L%Jl*jIA%)s&59SY#?If0hs#t+!gxMU?#= z9*umBn~dhq+pPkrOfj1yCLf5SmUO%_=p48_YU90&=reDAokeHdemf|-gvvQ}==FP= zSxbsuh^{-QHAP#R{*$1wB);Q3$a;LCNgJ>jCsXdV3mk65BStc z-9p^omUrx4yVGU6l<3BT2Y_W4xqo(6Q-`J0Ru;~ma&dqKtGIs%FnC4eyCHxwsQOcM zgP5c$7o0;)i5Mns{YWu%flJYFal_n#lnMGTe2O&-fi(tOs}N@^P{abJi%xch``=If z4h86ELd5q4GdoaxDtD6lR_m_JQD{YER`(SHk3yTzyq-_ZpV!Wl@VfnF!G@^>+Z;U8Ak6??!l2{faQo^*z@{87TNlUmTGD9j60>Exo z$22-IUg;njD`L9$POsC*wCTTbVn%_OHt1wbcxe-M*1QpzXoW$zoL=Z;ylztLQ=S|a zk%M|@Nr5jrEODc;pqfgmzw$X@FAbq?>{ghmIP=r>oh@~QQAqy^JnI_|Wm^1jb8*oDDaDsLi zwNukt?sZWXeL5WQoB%%Mt)3aS<+^iYTE%qOowrWO$eqRp1hn z(}@bRgl0ed9bv57zY~fEBpZ!8zQHcP3J?_xpIQ;YT8d&y0mXc1xY??zBfd7HM^z0c zuyq1}PKu@%$0rs!F$d5Nzw!hRwwtMf4#F9mLMM8!+Sa!x@99m&yw36;nH@_?Wbo)o zHMA@zx)70D_r^9F!Sscm4=#xrzgEJ@ylH+aeF;irN05?l+6oZlT~1adG)4grdrs!y zJanaKXLYX)Z+L$Y5>y9AalihMQwDdHKLMvr&0 z%MsH;s~j>AbmDc|ZdbdM?-^4`2~T4m1-+!_^fdL!Fb}p<7{o2jhv3drmpbW|vn*n; zn6zM|uyyB#JK%lAJ^F|$FQ=VXjQwE))a&)TLe*O=%&f>ii}rhYt{u6^S6^g`so-s= z9X!@G7H#m=vS0EwWqg*%>CD6~szCgT<} z$(OzT0T?igik^#r!BTpMRG+<5(IM?B+yMC3}6h zJ_aRb7*v+Gv_T^Pmd|?|Q9Lk7EVw3uU1o43Z7@xHRfq})rMUgGw76)BmiynOe4GGb zX+1)`PPXWM!irG8HWNw(P823$$K6@jU+2G5{CX|I?(K4F-{9%}vaRbuMur*H*kM=g zd}d-jg~Ou->b3=nRZpJy@Ah5cE@RB{XhWrTbwpee7il3iE7^z$?N$HOCFK5v>)6t; zAk)s#x!ONkMX%Y1a$G*J+H8PrDG1s9H$Up|IA{1klB=u){j^{>95HVx6uV#`?_@zXfTkhi-!xx+EE7#q3 zVkG#`o1snq2afJv#F*HEhY$M0Wc%T^bcQ3m2D5}7NvrA8f4 z`_+0sGRDJx_I+I(#iSxK&%oEMTbo!ZnYq3n3n{JI?)_-EPqoiHa0qY1&43Rncjza8 zHn4%EHNsc4*ej z&!NcTJ^T&0KpnA4U*F{M!UaDhknKkJyW zm1&UCreusj_19M0sW~$8N8ogV%r5JV!`)b`1l}_s#BSb+nKZ*?)*Z|Fo4$aMu0cqB zqSJp>(@x)e&Kqwso}~ZBe80a=ZDnvWW(T&rVQF(;r`m5`yY>$A)r*rQpA!49qrOMQ%`V+-puzwJLI|k>RZWum_&D*;MH7ZNOTP6 za@G(AW{QW7b!^U36NNIOzLA=MNc#E}oKeM+)=}aR^rCYP?mDk4-y~joRfJxX(C{ZC zZ5GZ_TRwMk67b#xz3wkkdE8opv2u5zz3yP#gD4qsflG}EvYh8GKh4@c%)*^E(nEK; zG^o6hrI#j)siJy{N{THL1r!4K9L(U_*+V9e8Yl3(6)qyhsR^G%mQT$k0wUG`%kx#vI}I1TU$pa@me`BDgTu0HgG5Gv z3;P{rKYTLVMI?Lg(G;Z*bS7ANZ?N65_u|dlR&k{+<7R1cLvQ-gz}n)6IQn3!`srRg zJ(|)TdDLakz@Fb5yoIV2yvDh=53G{hU89adbyIG!nv8MJzwjr#VoI~`JO4G|)*IU% zy7w02_rz2CzFJ<}_;5qBn_p)6(7cArLqZmU<3SDJINr)GrG~q_O;byEMV&KvIV+cX z8!iPRS`=JW#ciTMRbbK*`cf8m#7kMTW9Si|0r9rSzDa}IRnu1sD-X7w-8;3g{n@MH z=5^nl@9kf<`3~c3Z1-UD?p60%PnA9DRo_mjdw6B<{OtB)R`-5AllrJ|=VW;r}eLhok^|(L0{e<-?+rWk9S$yc{myLJ6xzL+` zFx1pc5tpB9ueh^%A?v{*p8-P-95^(^f95?n_sgxoyiG{Qz{RYlSl@?5NK2n*?;PnYg zOUi_5tLB)-e1}pl|E7IsOu}Z@hf|AWkE25?>|z?ttWwwS_RZqGvsM3Cc$`aWtW3zt zYJjW`OBUBN(#mz_AiQJ%=U$l;7Vpq3S$xUZl|+uTPZE-=DXB$^kn5+94-}Mi7g;5x zqzD07M$Ok%x)A72@sup7YQ0s!w9vz)O(Yv%_Eo`q*zc%lB-|=1tq|psWyY_%Wm)US zU6ol1H%4mG@E9X{iq$|4lpx5iX(Sy)UN@y5wOnb^K$iYI4P32Xb!h5J>Gn@hl2N<~e zsV*8WlE6_(k(S*mSF5bBl5t9eAdHL^kopzpD~lVKyLZoz-7Gq$Nbq5&=O#pZoYk(V zkKtE3!?Eg;_gg|V`nhO!UyZ9}d`1%(?Y?t;PHFUsW=g$VhiC=PO>`BkbH60bnfB@I zb-VdQZ8pEM`wUQYw=0;j2YCx4rPqdM)l~fbx3iP&pgtGf@7^p|U5+f#)TcMduHi8ya8S z3!E+*t^KXj$WJUADoKv@E>DUBuzC z=@)f>Peg~YxecEn;ofOa+~$kbXs zX^2yu5^y>`&};CghU zTvI26)_AM2VPzAP=7CzjA~~(W+#AT7nKEMxtkl}DjsEfYp5Yv7V?(7ki;PRFk-O2` z!~Cw<;+Hb!X6ctWnR(9DI6)Er+m5zMcb8?&z2EAXN3rA4!MwV(p;=g?tTboR0|W)L z?DGbcTbAr#!o)23Qa3B!)QzGINl{10GGz35`M8R;yk|ErFH3i*XKn#+NZA;4jTpD6uHgYg(OlWhTe$V~+FKkyu}pm&<6cx#xXFT9*=&W48!|d-Rf9 z|G#6WO91(3Tk&NYx!qsFN_N_XlZD6S-p@4EG5VRS@3M1U+PT%)@iO>hnKR>UkJ^(z zRhN=h!T1nl+BT!I22@g>Gt`r|T8`P}`w#DdYwn7d-awvz`JiO~#|?(XWrcmX3>Q+01B!XvHCtQ zrkUg}Z|l0^=4Q5CRg4LP#dmk_lhKa9r5?Q^tP5pVQj}$34u7O#d{_NKZBOep^CfW6 zT0V8_qvHGfO-L47hKa_GSeBY9yo^SSxxt%KK!15>69O;Pg9N4ijElXjWq2Q-az*dy zEd+6$3p|HHJ>(iX)*dM2+YbmUyr&TOun;?RDCJ}y{aNb3@ zt2Q!et9qnMVU>YMj+yM}J1=?QCe_9tKC4A-7gg98G?vdQ?dG zW`^Vlfn!9h-KM=?r>XR%&7dRF_=GxPg*yjfDd^@W>1!DZed2@I3#^~^4?`}^#JbeU5jrMoMWEum z$JDbg9xdTXforD^E%p6sdr1)P<IK%#4cQ;pXBe*34jK-wX%2FH@IoN8cowg|IF$)TOZkw>92 zE?%SZ*b{*(Xx)DZITMIZGNRTgQROMfbsQ*AiQ);7p#XGr4z)`ICFr99 zgh+o4Is<^#UW67pPy(Z0qsZMH2<;Ye$jh-`OQIY}s1R;ZsF8hl0QD{h8Lsp>NWo-j zP*+}IMXaO>fqw+cxAGCxiGnY1a7VeA%Y0Pc5HgyD2rwc1jfj79wy)I0r5Y0}(vSrc zTM{Mc^6PbcWx+Zvr$xBs7|Z?Z7x*UWzA{n#Znw?*nvnK?#s4oKeLW5R4GUd70x7wu zRmy}w!0+c`D0n&U5eEoOOr%Myf-VBtEaYAj@sA{Ai3D1qFYF&vL<^x+#-bKcYKq)d zy`&+mTwbytq_uIV%~H5cL!iFIbJOq!J^XA1)sMVt{0099NlLr%@PH8CpiXPh9BxyY zXGP1ORIp9$>PZlN*yMdwThZ_r)S$$L#pCPEa@6klF7p&-!ZzsPP7Z87BQCWOI`r@g zbyJ^crE*nSPcC6m51-}W2e>$#C^$MD>^TXz>sPXY=w0 z?v@w0AB1q`TnVQ*uunq>)e#FYX8Ig{8R*!g#|tjtOQpQWTv(&YLsY^2B~;$9eC&x~!kpMne%dF>%31+q5Bx8k!@ zqzxwMkxQ|S1bL8u;>O&H_X(IL(UMz<6w{T0T){!U{)}8}s!b%Jf;32;WP42@GRjmG zDnf>E3sxGDM_eGD68h3nm!X8-_mK9aV7irVO-5YdNTe3TuQgK;M?p6;AW?IwND7ud z6f^)h;^KjMXPi>>LpBHR#RURh0t6wr?j(Q(plPk;2IF?WQK%;cmCD5=t74=$KB~n1AW7)b5R?GA`2)Ryd$3)Fmwv!fU2tWCxKsKQjZ#9OQ(B7z?f_ux zrq;TpF0gA?2J;2*NKZT~;w`E`qN&VTXzi;9_NlF~%D%l)VB=-0W4k0LIPsFwTRzczZ(Io{IXPP=5&vdI z46vkT24&7TC5Fd;ZN%;DYJDWdpVK-YH}*eEhEHqq-bi7WJjW+M_&bi}dy?ZLj+x#A zFGyhn%kmn@woCuj2a_Xk-S({N;ydn|0S-5Hz1^3u%?MAIXJW!rmBXJ-mX}$unE+|h zhjT#3#ft>!qqBF(ImR3g{YmsxWEtk^UPi*e$v878QIx`NGC^Bb;Fg}tD>&WMQc|Sz zQ%Gopl371wsz_VIaNl!%pLMKiX8ES7Hm^jQy9aQQTpYaWq>w9&R-(o}Az55x5DDck zto*hOwT6UIrtA=C3XU{)wauY~LKL?Jb($TM#)Y?WaC?p3MQIoh$&RA#t*ez#5-B>v z==!t?imJf;nul_fAVZ7@j{tKP$v;UNRb2zY>roQ}C{zHFtVFEQbW)V4T`DM4>$lSM zv!e$2pE+cR5)4vxmIopWg-Cg8)H)Jsmp=OaO{hhae^R*R&3jh@kCx3s{rhRTn-HDD zLHqWi&yqG5FkA;eLU(hJ+y3^_=uszKNzb~M5)oR$k`xsvfm|>vGmQ;1v+QhR-0X^2 zzysH$TCwAOa(YGnszqD_fd9cnNw}DLMC)qiKYNh}E~z}!+D%CpaqCg|HWPN+eRs7a zc36qK^%ncn^=%Oifx{-~ff4`wJ+}5QDB7G>0`Qkr@Q&s9E|F`SEBqBb{wps25@$8{ z@0O#$Cbbl0?$_3i1{0(hgORwX!p(^iUCfQ!60mbU*l`iz0%Yt2EH!2aBm&X>0N4o( z9pUJ+9zx^glbb64sC6 z-%G-;8R65*?aU17%qil6w9V4(W(yVW0oH<_E&kMp2PCaEa2xWJehJz&+ z4Y+PqK$8OhOM%cQ8;?>6t|jlMvoN%W>zc6z_0=yA5;#y32yJ?bTy5%1l~$JLYcnL; zCLxrpLb3EwKno;LAvb7IkwRqqg+m+k8`?N{hg8lUC3HOu5=fz7QgpNyl_rEXNF$fJ zH8Go1CpQ5#(Z=Y-*+QR0WY~)XHu#l4c_BTN2nQ|7k%S0U`F&}Ic5_gh_My4lZH)gQ zRtxI}4kE9HV7h{kA)@X68n2hH*oEC!pRHKgY?KuK=;o>gbG0KrUzX=^z%nJ_sLCTg zr~8@$5^2UFSU8ag6(9u;zU@XyY!D`kNg;pr zs;0lNZQKyW`rdxPQY9Q$n*xI3acU8)F=6#DPal=6+SHKW#lara@A)bKJ}ino4590d z@Y7~ogBHH5#5XQLJIN=GX-Y*>$i&6>scBn8rT0R3{=UMT#1+is> z^c@%98W3^eK2wlGILrF#*UgqsB!tr({345P;yCl`zqxyn-|}FfkK6vljEI)NCK5b3 z3|~_@y^~lj3C%oh`+zZGSh|_zX~`G~n>d0WxhS0xb20faxs7>btJAZri^-Yke7k!! zQTCaL!1xgOa6x2n9f`ohWhfKYv7p^cT(iL(X%+8p@c?%ylU*AfT_$8V7DN`{;YKcg zPix?vvXjl+D5cWxq{%x!!vF99Cyj%-I8adj2&o8!ghRTlSdT5vxDr*ZK(#eMiMaC2 zkRgP^0VoOpuXLca^3pC9ZksCRqz1hgKt&0mG$S%z@?)k1`MW%FlO#4%ReRUXw_Rn5 z`5Y%S`d8GTn)$e5O&n8+DwZNb^~ht*=66-xg+?ej<5`Fil_0E3p4gJ~A7rSi+Czg3 zP!{k++H_z#U615(Fff8V_Q`dR6qBKaJXmOZ`g1oHy37;p%Ju%|SNE9?SO0xAHZ$I< zSQ*IW8EuQe(oT}E-EY1@CPW|!6S>K)k>%I42~{c`{kdrTnCXQ_`|?KmD#^;eb3*(N zs?#4H;X-&f+NkLn9bZ4}!abs=mHt&MV&NaD;G|^ej0roWjYt_^*q--KKk#E*duP z-m)`uaAVgmdoSc~q@3P<)>C{e{Pw1E*^B;9hc8AYo+)}5El)tfJ*zX{mRwF4tvP73E_71R zF;-9UPE!e2T&vtb@oVZUxbZ8*IOW|q6w@&VQCvTz4{54tZz!GK#co%>?oixjEZcM8 z@jqhNwQ(Y+&Lh8jrkYhgKU3{!(JTHrYFOJrbXA#C^q$RS{MukMC>&4ihi^{d(a2n=S6(7E36=reH@Yl<`i+LRidU{?8OV};Fe zsWa@c-#Amv>@?`mnAbIbgn0DTRA8NpxI7Dn?`EB1MOvD;5`T>|OZL5DlEa(>pXK$g zgF|9moMs5D_o1<6AjV#$LSSfI6_Qhkm7(48CRj*IZkm}f;MTM3k;Q2t*}{44-Z`>B zYv`oMX1$qXY~f|}!cWsyy3n*NH`qQHMZ zsvze+s?oP#ZzpxfFMGqR{9e_$G{>%O7|Y1w-k)Q{Y8o&t=~{hftKZuGlXY>0hBB zRLkCso6!LY!CR{tw>2J^V}+U?#j)+f@-FJt+qc}9nIhfwT|Id>ubXkU>6hEfcJxDb zCB3z^j0Y$BU$<{h3B3JP+4SX?Ek!}C?tY~iOa1PP>4ts_Gx~wYwEHRV$au>}egniv zFBkWIZw~)&`jA%mq~7jv)}^AB?ZzXJnNa&%&0utG=}(Xvo)zv?d*mS>Pv^qDH`Cq?k4?h316 zmBMq3gA>k^9ELTD;4!m_(KKRv0yyKNyaksk!yGh1e9@emQreBAX*5-A1y|}|cB!sV zx(JjYKElW+><~H#LPwTDCvR~Hr?r(f2k^LvHk!kbsM;cbM!rs?v#e&J>=yw1o(cK! zDNZ$OQiKiBB8Uk_mB*YB$u#gRxi+YPST#l{d}+~^i)2r;;2oYuLZQ@S!{%EFUK}1? z(~R;d!C7nqMk$@)PHViK(|_B=MHbWVY+4NaXl?dkc9#2ECs-hIrXI$IyK4|X7ft0bPq#sA^QjugwWJ3Bm+x&**T)wjsi$l3 zKAOc!zao%&UW0Z|T8<0pJ;P}Gtni%0pi|2cxgVaP-OV*?A(}X{|U)*Su-7 z+Xa1?=Lz!R(lU_VWvB`e@M^akUihBXTo+8w5Fczp+L531`h-1$Q_E0x79LP+^hx}8 z7nj*xnhy{3pT%uU`egf~lSN?LY({2QrN@(?qdsGg;Pi=;^h?t59fo-G+ts1W0mB*g zie2zkR;pdRt|}luo>8P-W|voVBXH0uUY+*`ec;Hrk48sKm>h8!Sqcxol7%>znUu>= z*MyVnINHFyVqAZn6CqeBqm{EjvvCizLdvo+h}I>jK!+VBMHE|HJMdn%zc~pLZT6rw z1^s$A{Ea%=Q*QR0Wd@;L;I(JmLILFJv!dj zrvK65(xGk?bqKC zp@Rha_9>&v(L#*ow#hrZ)4OfBVsu#BlDxBd6v;?CZIL^aWF1VObo!SIX9($(S)szS znQ9S^)lnMtPUs*m%DE;VNi!R;Ni3ujz$3J(kn|EZ?znc;>k3fiH^pcG(bvp;9+a2aJ_=LtjDnsvI`WJ$$0kJG|Eg6^-4lFne;x7~7EuMWL4g@+{NDU!R#ugf zzJa?k;>U9AP87&+M@vDgW))>f6p5!uz_YDp&n_DolE=fDNl%DmD&n4lM%KN-fjq<^ z-eXE*KjGjMou%?P5&v1-=Ribrz;w&eO5d^v{zJN_qWpC=hbq{ z^ApRgh7U$2XWb{3PIb|mYodP(sct+`I!-+$IG!|%sDHBXVenH;_0~c4l}D=FtH-T> z@VP~{?ERBycOu-vN7PTKCHXFIuZPxa#CDl_XGV5k96d6=+Ih?=IKaFPToVE(Fe2BK z#SvDJwBAw{ZiRtw|K&!dU$Pl9&@1zjLB4u0VMgw)C8Lt$9yM}@38}XLL>&QHDj92D z=Dnl`-BtY#L$2&LkQk4B&MaAybTB}CA8$W|l z^QaJQReTw=(TrZyR_!jpB$Yv#MpWzs6yMgixfjYXRwWFf<@Xf{8cc!!dgM|0^1+pK zW0jvCm6TSQQ2-?fF-4uL4&SP3sI5H0hDJ<^cr*A-O9;{1#F*q6ZL7@R*~G34S3;-S zumWji&oJbnU)}yY4NpR)14wATx%bv-va({|Rt&O0iVhm2ssLI6aJt9ViYrajoanlZ zo;5QIH8r=MD2m3;l_{am>e|?ORBD?b&bm9LO|e1gl{VjVq^zpQrP54c-K(rT86Uh? zav)n(d%nB3NLrmDs@^W~z4Q#;IJ5*k71n-u16>aGJ11Im=pWo2yR}8eo0H_LlTX`C z0|(oHgXtQmwEI$uM35@Ql+M?k)D6FThHYmlS4;kmIj35gv?O}3gn=J5OkQA zg{oz%HO{trge`!u6@qs2D?f5Of{mJ>8hOxXnXQ1dArZmU%dF2|@YW(6C7`>AWUE?E z%accWvm#F;1I)Jmh88cD+@%f0wr=qeb%s6xgN-u(Z|>r2k#WjVKVgHrHGTf0`C)|! zIi_++Yw?{2Eta?-`Hh@pl!r>e{~bB>hO1s-Pnl0@HcVEm8B?!)21eR8a-(QY6JT_X ztGx+;H2_Tsz($#a0bwZx?S}5ysANpGe?S|;VH3w*Bc1*NI1c$+ewHp1AZ)lGD^0e1 zzAKK5Y_~(Og20zyN?WDeGYt$*lQ-P2aMiDHY2#TINGWMid1rnER7k%vxm!0fSX$xA zK9U#&`6(5l8f4ioiaHtexV+L$ij3n_CJQSQ<`v-rDC4ujPXHwfD$~*wg?dF|S@rz~ zwI_$Fj+dj-Ir>D&>g1uSIDKW-XVhcz@$Es?8C*=sXVgB^l?+YA?i#2>aCT{Lckjme z;~TY*8B)1Z25sc7eN=ZNiuGvwWiw6)zp3rgw@-$D6Jyg!(6=L1KGq0^Ko%Dzqog6f z0+3|wY0sCJ$z_mGswkDn&{6=)mQh83m2~*@Y6NBiDJ-bYG*$ju7?L1mW@_-6ZBQJy zGSPr8{j7+UR`Q>tzL8l8^bVPV%U^2`UKR{}TU@hUNRx9Q9N~slvL0 zpQB2pP^JiR!_EA`^3JvMQwvecp(^dZAJsyVRbpIe0^6kr6}~Z-f3@ zSoz1c$|u~q^fpweApQlXDafdDF@Uz+G6xgYJyzzwTFI2eu|-{W99F0j^j13AazT<% zW^a-?rd8lN6!a@{4}sjVM*VIQu(>0%2gWTwBXkf5dv{D-?6;^@M$LdF0~IbojZS!# z(*XYyi9%aIteKWM{{)7z%z`HJe+?)Ge&o-25c{ueS^Tp9-30&lUM}F!T_!T}XH-CA zAa56RI#C-Mm5r?9E3f0ql^Bo4;x%oQxVWD&%K58 zR-))0U`PQ13JNrnUkh`jfaw1yy7RCk)An!R?5iLs;0l4FqM}i$nPLGhskxw;TUNMc zW)7Aamea7f+d^iJl{Hvq*jS;N**0^7OJ-JRW>nUom06E9v`phO)8p^G-~apv^$?Hy z;yTao`8hq6CT~*xhblie6*}gJKi`pTv18W_v3l-^R#CC0EI0so^ z;g^NJ3elb_nxTrCqF@1A=ichu_L{l|X*w4k>zH7wWI>N^xIhDJkY>b^4v-)hW`nvKpMs$e(!CF?JMP zBtt)Wc6W~nqc2~wkyH6=V@>VE$qy>zb}4%M9DfHX`=_U<5(z%3ruuP@2^dcMbL7&8 zox>@d%Iz7krKF2v>rEd+t^Qg*JFQ0{*|6wT*q7ide}=7Dw{Oee{eYw75XB6IUTU)j zpd2~O$39Fv52Z&!8Q~d@3j)zo=)*xivlMGd*TY%6)2ktW6I6RE8vp%6bc9G7^rmW zgR+euoyU|rj0m4*M52ZRp1AMwm94>0a_ZfUX%L81#p!~_MIz%CDq^Y-QJ_US43C-R ztoe*sYOYuht_<$31jX>}oXYsX>ZC;-QkHt~iA}1w84<^=OyVG;FCyb5h{$8;C*zk2 zQ|U41s>-z0C6Va;fz^Qht&-*qWoopnrdoWldb_q-rruz?bz(mknVedYIv0>`s*Kel zQ*_7@^~A0|)Fuo!MvV~X+g(#5666(esli;C7p1v~lCRiIQhFULic5f3=A^_4|{mq*yWRbdNYg=W4#lBE; ze14{C&EKh|`E{uPSJpzaPR~b=Jv((mou4q*(jB%@#&I@Zc>7ZEG;Zr^ zF=(ZmR-5a&?Q<22-BYJ_1TvodbMey%H^q$j^?*~mrE{|cSsE$; z>Aj-`x{ckN+wUc3x*TlKE(|NezC0l$#)qfDij26;WdHi#JOCnTJbs^W3?yKq-vr2y^Z$^r(SbW;NV=8e6c!S_bo1D@X$N z!-!+P0vLCJ>8h^qR3TiO6)tK8;;#B_g^M$1J6P!f$s&?i=OIzMhp!unTf;)0#@LU zJh|Y;19`KKqel_xsmR4N*w=uRO`g4wd}z7_7U#S&N?KWdul$h8X7{S=)Aqkv+xGpC zP??&a^kNk{=UnAw6AEGN70*F#0IKBcD|fe{S+i9`w@@1b(fcrgr^xsaNyRR0^=KzL zITF1|QW?>QN(S2?smPL4d)s*Q>-g$$^XYi;{$wNaa6Ne7c_j@aH|4*|Fjo`~1_uvU z$=_GRG$T@E6=^aASpwz8L*uq7Vo8eITcGQn{u>3s#GER+yZ4-=zY4X;j`Nd_nFp$D zJxKZ6(d%elX{Y&jZ@j@KU*6oBbGLhvxbeD7cj5Nufd|w}1>baOp_K=px9ym^aNVo& z{)^7UFGZQzYrRjNkLxXe<`TYiGk@Q=!c$XCJ|}ONUhKPia_bJK;KIu@+pgqo-^}0s zG5z+5I}e)p8!zQ#_Lt%}dl!Dny8pc0$EkNmSeK#!x$bmuVXwAs$6FKKD;|uEUj2rf zUsrgs`|d01pQkSyZ|ZWdQ0?ixaeePErI+73VM_VXd-c|1+y?KIE1TclkJx_lV_J9g zj|Z68ResMWcJ---cjA7Vp!W+8#c3~Tm!8@5v*4vF8Vzg%N{<& zTp^p#WxiUgru*$Tnm(PaMkhj!AgqPo%@x+d7=5Tsh(!}_vlue1B8E+{?mBNSFGhl4 z(3wiR05D?7&dB+@-%6PAsoy%7O-2xL>4I=Vv}xraEuGW)f)*s64y7i_1Z$})Oj4*_ zszfnhv!Spn+%Ak;Txk2a&V{jsII>y*euagJdNVhJAWEy_FDy-nE51sK{29&HV+u=vuc(_IJh;7 zMOfDbpwfo)Uf`|@mv6{-QcB!8;;JgV}b{^)Fx46a!J!!kL zEa*aw%R`TwZT-=|p|>85_W6aI_|Uf{^7EkkWcjDM;8T2?D341Lkr>X^0N{(&>IhHQ z3O#I`o9}qo{gP?dJ+76IxjWSQ6a?OK(nH!7shBpgj18X<^_H_tJzL&aCjvi}lpKeq49+%E0pa4LzxU z;5;kGJ}XH7PK#TR{*m&2tL^D>m1f0)ME|4Q3}3>|#J>InwLP7!_5TrqT@6635&o52 zEepF+LMCVc72(EQkYM3~ZMvVzITft3n}h`5i-u8h2fxMyb88eJ+!D@&W$?zV`{Za2 zq7`j7YN%&uo!6v|peK*SFVd7|xJw4Bc8w86FCn2F#`GH3A0Pr$`W4Yb=e;h)Up!-djyjIPvh zuSH`5eyvPNJphhDmlhK?w7qh^8hM8gT7PAJ%FK7VJ0Zp+#NAV_^c&y@)u0>8B4mw@ zSJmjCOG140wk%0$iHGdj?lIP1~$0W-izM~9`P#r?cqV6p|lgf^}M283#^QKw!FR&P&xf+ipWk_ zW?hIm@B2-?_7l&Jk`BCLwCUe+`;Ht;Zk(aM%~}$`8z+iQ2lF0l20d~XRZl0M!m023 zh5}|;hf7CiXt#mzAelm(z>o3iL~h}vXX3>g7pJFLJ{yBrwcsMYrRQW3{0oXhg>$~N zlO!j(<(j**6%M_Iq_pBP75^i6Aw5&Galf5aQod?dt&UV!fBUL%v|L7cHveG_hR~e@ zS!=jwMAkiG*T48eT$~1Ky>LzC2@xJzV$xfw1>m3*fcp2YsI(Ah4YJas)SxAz4ZSlb)U-hIzJk^r< zJs}@r^^5Hg82u3w?T3LKXj)NW|LU+%Y%~z#j(4o6fhM_W08E3pnk(l-rRZGA`C;hX zNw5w%+<$N0qNYZ+AqMBEuqI|_dfXZJ*s{tMR}}TBl3(ekjN3he^yCyK()QJjb`STl zx>J{StuSFYD~&?J{z(_wX&@!wmKdL*JwvnpwK8;jF=5l3(C)LTGH563@UGV)YG>q~ z5KdJ>`Mt9Yhrz#xn574IPq=1dW2?W69~n=|aHVynUZW4+Cq$P+ZMOiktJf^E-mPvw zABjImUf{E4`j66xY369bRKnB)s1eG6%XnXsT{SVyl z+h{r4`H@)};=6uj%KRbPj|8Mor`{&6r}WenNvU^l*TEYSxZS^_5q2MW*g@)`pND6( zzox1AMt(M}O?xkIOD65X0|Wh|h8{e;9iL`$VXV6w^!nE38&3w|C4FM&L9qElO|`)Dt^K{zp_&fP->V+9I+z2ec$Y~_Jo1FODJ?Lkmi%y!I4BNO zJAO>x6XPk_g)8ZS2aPOyplHp8hCW}3B+mipK*1WKRuB|3EX0CW^I7LQ)FgD-|Lp|e z^mz#y1hr`e>Pq*;7$H7n*x)>=sR;ZkAqKvLS~qGC0Sfzd->$>RR$oV06V3R)9hAuV zCI3fH`}==UB&OmNLN5=55+~$jo}Q@N;ykLbvIuAv2w7nG=W&g88?@S9$j%Y5!wiHl zrJV$7SE3-#*Vu4{1PSotZU}l*K(NS5K%`0ru?Fti)Y#X=DiG>4LLeHG1N?G8U7UL7 zr@%YqZE{4Q-tsi%EboqmKK_f*q^XQLYZMS7g*N#L$CmNoTK##<`meddYMGEwreIGQ zZJa>_3rZIYseoYln9vKHY-()Xuu6`}7eM?J^dtpcD)MN1Srj(p(FaCwlpfl}(c2o4 zFUU>nKzc3IV{{|8#PEYw^(hSHwg9;AqYuAN*cU`&`Tco2UVzM-i*toY4V=F&$8ZO5Okjl-Q9>GW zuMyqkIR%3hp5<7fBs2stc#gv6lHl|AMdD6c!Yu>0Md^{0;;B7a^pzSwI{dUYy!e=i zZX7C-A`2yOFUtncj3$5CI6)qY-Nx3<0STsV0NJsh=w|JV&Cya*P)*cT-a6i$TNWHl@cx#KTpOypxmx z6QR+xH$~3(Qnc0qok~to=$@o+?lh zdhDykFgs9E8wJ{a6i_DtB0rim0flzRPZ>?EW&vFcJrt;S(>mZMo9piDttX%qfx>oF zekoIriBOVFfSn9VCJBEW1{MilC35Ooryu1ghgK^N6f>L#-=rpJ?`U%6IEp(_n{d42 z&au}It)1%0#|7p`PWi0ItC!b(>F)Y8Zffzc$2O zJn37m^qo^)UuE^hPvp}X%6pIZ?hIYtEpipZ{X10=$XCAQB3BD6pOdt8!T^4taZlB| zq^7(aTnzOY2=$C~x$Og^=fI5=sxjmJ@bQNBTcQ4=s(%^{KGJ@81_9A1fEUV*qt4-9 zoqcQKxitpABd~RILur`j+yS^3rEOc`q3oykb3UQ`L2guIxW>!4_MIBgBAOq~3GN(9 z+(z^n=;z+*_Z&sMY1zEi9kqU_pDzvFVhZK+&UnjU|Ltk=0YpoMY#+-De)GeG!qCkF zp+Uu>09`1b6dv6?5H&iKdmSF!2-7}0woUD%J!74}LFL~Vu`wTcKLZ3|L>?2TqXuGw z`LXGDQ9>KL&}XoRX%G44gD<`|XFZ6Q5Pv}?$&oRCSn>0S3xLZzLDpUfVNQTr1hZaA zKnV|xiMfL|3Ch?C8{IIWW{Hiw-DXOj_btyx2Bk8gwoJVvQ@=R~ssOb<6X0A3rA+9l z_6lW^0$DPB~ls14EsIZ+7xV=_5V+`yZq197>@?CD_ zsdt_qrex^Bx#&gg2WkeaV**OfFf>Z_wo?#$dm(`#Jme;H{GS}nu0}dTDRV&cs9?e0 zzCH(_l*>u0T9MzI(E{i$v;3#cimEqy#l<6<;SXpDWS?-HX!VD+2`rmSXndh{o5G<* z=y_a^vs#bkLU3wq{_;rLR|xp~TTOvi{$X1Ibdg>tWSF4j5(D;Y5>~w7a^!}~Q$joL zQMcRN6$uAa9)6W{E%=C)!`x#jnZ5xVC-9SgUV>>kk;`|#j7=syybTZytT@KIEY!5dsZHafsDDzCji zxK`SKdqi0R9zmp(EfSUY#Nwge2Vq&dFwY(ps69#s`d5$+#~1edr>+NEf%gTmaW|Bl z*C<~=mGxQ_F(V3?x*(n;aPuE-wC{dF$(@=A4T$7xtK!)?RoJU>12Nb^0{# z{M-lkJs>K)|9W!U(A!2t0Jv+CZ@Wgw2A|lzTI(I8@x@5UKEe3jQ8|fEB4Xw+SZ$Ul z5iJvdGrBknKxY0ysw*O>9c(55nXaak3FYQ0Xu3RN1{BQm^F#HA(za9#3Mo_V^%x=T zwW8y8vyFzuFny`Bd~h}cRIT1JY40Z9LC@FI2SJGWA~Q+kHlPO!oo-@-^#GKi`|1o1 z8!rMjG^4xiaWd|)9^sC)1CLM%lr9Hsxq5;TeOxBDF#^T>YNTBL$VDiJ2#dc$_Juc~ zF?u3m%gT*%Qkd~atE1q4@&=1buh89{v126!L@?gzlz6x0eiHC1RgCtTN6?+(i_`gH$~BA=<}z8}AB zz3TZjY}EI&Xt~YG!IFrzA7H0?dEWUGA%YmIC%k9Qhw{a$pca*fNww&92Yfb)xE0?1 zMrC&|aOZ!Nac-CA2IcK>(U`D1SSuerCb)1$jt?8a1_~H8u#mvz;Vy?amJ2^#Re8t2 zd~|T)a}jr{-+Q-F+E86mU z;K5Q=@O<2VHne3OKz^!l4fy(a{;4Xxkj@dqCIHc!#|Fj~m3XhHq^rN6aoLeyb_Rb{ zW)D7iZzl9kd`d7IxHW_|bYvmB@FgJ1+HK$r|C-XL`p6ITwM3-y&!u)w8e0vkwjItj zi3oFf=u`mFvJ0n0?`|tZXa&tt)jxJb6(j7?ZPjH9ix8Cs@R(+PW9X^5pu8X;TOic( zXmBC5g#Id|fzW)8p2miP zUOa>IsN^H0f+?i)kI-_U^a&`#ETqfyZY@X1rWLkA0j@fP<)1g5xY!D@@e@*e&QBb=yXMXoYfSpzxdH-H2%hlxKFGx* zuTX(N{UJW0LXHCG&kND-%%auM-RmC3&A4ui2^OLy!A^Zc^Lv#q?C_`8j;f5G0 zlF)*u>QnZ|ea@%HPTh-5jW=ZVxA1KQk_CBxbv4CbCYw{gQlTADYr?)BF&NNPfy+0G zxfN~e=&xDOw!&JiJ)-ue3eV(@P%7eO&I#kM&`pjZGl2=vH*0Mig7>~AMXOMU9D)lG z$3MS9A7X_ZBK}wCihQ(0xPL-_Hg8=^bVmMn?I*;@F>i4p=WgQKd&ju4-humPUgOBV zDUId&tS2=G*M{|Srn`Q|Q+zfzQaoWhBM!R66)>%yWk=W5&}?QQWL9NG4SbjV3k7`m zki(L|p2gv5#K_UA&&D?h5&JlKGN$R%Z8AKh`BZ(_DYZRCTuR~fn)OEmSsg#y> zt|>J&Dwmp`s%M`nDTz8qOD;_tu#Hc?9&s!_5kyPV{oZAfc}+&uiE!RPm|b#7#~a&C zn={>LDXXV*>^3Fd8nP`4j6$E)mz8X)&9wp zl>8%L$yIPFhlDxnI^o3GP7Rp&_nyz-(3+P!MuzuLm-qX#Y^^68mC)c+ZGvnhwt2;{!ny-0P z%vSd4t5)d%7{aLt4C)TnF=tS2HIeu_uObetBJctxvz;L_D?T=@wy=~~8f3bg$TsO! znC6sexMR6^^%>%74akx0#cag7_8DQdj1n2F)=eELLR*hpwppYp-P(tI~LS`UR~{x)JjY z!E58x8T$0zJ3ZZ>h=s;1r`>Z(|8zy&jz#qu`lbBAfWooPjkQ0C!(B zt5;fc^l%5Q;*nYnXCw*Hul9j|LqWCwupYZQbxZMnBiu75b7^XG1+CYN(LVx%y|KJqb8JQ4zdgS&g5+z zHh5aVd2rdZt6v(7wqH=;W^N}&r=FrU2@vkaV~Y4lC}x6L?P>o*z>=Eh+g7Wg&VlhT z8}LSzM^e1?4?~7507#bxm~{@6#VUg_4v2p?3@Feu9nifWCk3Ath_0)R3S5BbiYrfu z-xdgW93^d`HA{jh_IlgEt_rt_X=1z|L8$8FNU_uP=b7gOv>xN2f6215;R7@4*8gME__3Is z9OfCB+VIvnHtJh(A3ksOKukdcjx67G|YgG`W6bWES@D@SvSI8kahH#&!Xh z5?tWrTi@&AbWw&3@>*7#Y|Kji!mPa9$>PNqW;1m1O7=6dc%$VDwHsVlR7k)VHb0^p zeL3C9x%Hsl({?!P4mVm@U#2^|@!_*}?-`a(Hh+fFc>*1f1;v*RxHv7&Yr=1~`1_VN z)=|p9Z9CE2{-`Pmd)dhi!xr+Tufd8!JtDjiQ1t_rkg5yJAZxv?DhhG&o`;9MN$~EXDh{)!K zxVpc9h2Os0K-~W^ctikQRkuDA_t#BiG+349LmVd-li{S%hyvq2k*%^pbzzwhCvJ8k zq`#{03{1hu|Ao-H)d(m4G&ZV6g}b3ac||I1z_p-wHXRn#{5PsB)p<7vw-2lG58 zTql|=0ibrE7D?biZ~;;X%A)|eGZEsFu7TiJLLjky{Z1Fms;?5lXE_q;s32kTs%T1Lg9|09d@C|n3&Zo za08CUY2nPYbBThsIK@2D^Mi375j3Q6?uI%oK7e1LaCcu9~VxOz(O?&Ujm=Zq2c{#Ng(kvM2Ga;i$PkW%I4GlIHDQJ06r)#Y_x=j2Iq0|0^^yoLK z7xawY+UW@&ryhO#%hqR`rz!cZ6oZ?SWZEJ1oMU0@s|Ylq%_H9b86puN#x+Rr!ypNQ8ev&~6j+-jq}2z4Zpz7>!jV)f=p!L{R6%vr z<5H(9ZvePqpkEE(E&w>rpIDK=noF|czqXAQ+O`>OQ=#4@lQ+tb{L&+<%=_VN?n z5lrKMliMb->{5lGYxZDXE^w%#()3a76%jJ1)o+q}eypsLX=g19xgnFfxW0C;#o2Mz z`GKC%qz>CIq|OM*aT?@o#gX>|Rt5=jf=BDrGnj5|I$opcY{)a!I7t55v@*o385r`M zpp_1qZ3l8~8tzlak(5Pk31c%WqkAwFQ)l}<%xP!thLOyMXZIb7GeINO!K3w=)g^ii z(#bSUn_@98A)V&+?R%#mw|8ww*FVZ?-#C?N`|Kc{{)u@g;&mG+UiQ$ehR%;7@N_0_H|ZOBQi19(cqZ)FlDx5;#DwAU0}nF4g!MGp;*}*azVI%w$e0k&%gi53#ufWd9@~ z4lGwR$%z0nS1-VJzDM~fQKd{Yt^v+xVtmctqL=39NmwDSa2NSrDO1LW6zq%x5jc#C z0G=n2HhRtW;IR2-tl=)`rx|9& zL11)A&mMBVjuOP8K43Y`%U>F&!=!owV#kq}IY-_j;$E5=otaekjuAe~sV|e!sTp5g zi^z~8!X!ur;Ls^NUQ^3ZyD@TDc`%$)KxS^Cf+onMwRdQqWljwdDB~TDmlRGj%aU~p z;(lsIx8krMkYVv-q&_x>9eaHmNvQ6i-a^J_%glV_i3d7mm$2EdAcBUK|BkhNfXFG+djoS_&A2{}$ zqbI-lC3@Ow%rm#K!|{Sp(y5@dL)*RDDRIJ2Y7zQYKe|EU(lUZRjb2k=hWlD0MlW0+ zm~q{$!ARC(TQso3^Vd(*V>#=vy#(+j4P62Vwo$%P9m;Ex;9X!Wt{K;4 zgVy?@d*od*0Hx8axMD_^K1Jn0R<-SvmTJ6LOyV!eZFs%iAg0WPa`AZ6ZHzq4l0<)5{3vJ*H z8dQ#%W96}Qt6;HXMvm=dT2BhLyvajT9{%i&U)3YWnGM~z8sw25GX?q?-USU3qsP(gmESzQnM;vj&wE?NGvNy2u!R&8kmGB2pOz+ zTI9qDfFLmwPft6}DjXA8RBbDxi9~l)kS{@rJR#=1$qdhI=YV5Uuwkw{6({Y!%vZ$#M$Jq$Mkyb_z#&q0m)OEX6ZUyeNtLn1i%YIDG~j zW87#3tm8v{*szYs6bhy-JR7E*-A|pg1+9m>y1vmEJiHWDG+HF znYl52nkIeUFjAQ!p50i3RK03tynIK?u1dS%mR_*qGiTVR{!zKJ3wCd1eqpf<(XQ^Z&?g+=9Tzq zYTLz^ABWTDp-zAGY|y-MPBZ?%_g5*4frnAuiZ)XzIm0qpB_;w8r;)>Nu2s0XDucWYziA`MZe_u{W z)mRl$RME^2^Sj{!NrImM8p%XRB!sF*&`1sRgqh%MhpDnPjGIZ58f>!$>#~P11%TiM z5td2fvaD%)C|)dT!-+da#iFCV9`Pj8$$m(%nhX1{Y7JA~FM_~rgx2jVs4@CVOw#Z6 z7X0{$(BI9Mglv;=K+q+pL{H<{LYsMwE%Sf;TFF2i6Y+rw*;GhA-C7l?u>C42Po5U> z)uF4vcC3U-ta!V=)%LT7@>%}YK}_w`7^hZVO9A2xazx9tbH0McMxZw+CTBBA5XfRQ zAcmQc&H!Pi_UF!NumeRI?V|eLwS{!GySHxaRye*?{9+Q)L_();6l$K(rmBO|DLmXY zO}SC~^MdW!me`z$H%lJ%-PGG3j31=u6>QqsLEk%VzLo1V74Dd;f0_Y(R>WdFxJmg9 z%UD>Q^vj|f#RQFhJ*_E%QY1$g0GOhghBSEBj#R&sehq#?B7(rX;H z83@igk6I3(Y9ya@*%+>WG-(2un?F2!*srS&N2p%ZtxQ8 zq*0%>M7O0GH*z{EJPy}+r!-D1wC3D7%wB`~##!|WeRxU0C25y1D(W1!?Q{M+SZxj4 zw=pu-C3gqKtuXpl%;)6|!vqKa0uuBSMMZR8^8Ztw!fLd&XXygCnrje^U*)1TgGTpV zD+lK460lWlYtP-RlL$YvkVhPs2rO^Ru2FT0!1C9_(A)?VgYln?j_%44a!_J3FX8Aw zl!IU4x8|phTw?2~-qB)6qWe-3+XIFH%%Z`QE0ZPKZQUrM^D5QHTi17asQ;w-mRv47 zT2*&4xM*^C712Y+47~EfZ_8N@`Rw*5$vd(K>&REbZ`j{fTp6xAX2<)OT0A{-zsjrR z0!ImGqp5)`?Yr)7g!3lS7}9>-oN%q(I=K+xakiwj+U|sGdeCJv7}Tsuhm_->;1fOc|Au{_~-%mZ0d@8 z-j@L7eg2iv74ZQlrj!RmCM7n}!5=~I%4f1rb)S9NGIQRfQgF>kD> z_qo`U4TeVhT<`(9zMb`)%eq~0J|7Spy#t;PtImg*avqIvKlY4|_*{cjZVSH59wF-b z26n~zpUcaxztIhRu50ZziPrI-7OprPeBg#?_IRAz&2p#WAE@r%DZyGb?TV5t;oqjlAAL zT1ZmypshdiQ<%;ANYHAHVzZFe4*cRqa&-B0CA)D#A5K4(w7k-Wi98Z+9W7H25JMIu z{nq*a#|x#US&fZRcbi!OjKo6nS{j%V4yQH7h@#xu07Dc_Efdy~Q#ryg^16~@_?!6M zYM6D9d0v4G)JjZfDp-TCUSc#Na6uM95a!gc;tF=!lq8Ip!tEM->A0RnPq8 z5QA0Gim@{L^7hjmp_a$7P{jysz5BF8gbawUAV-zOp7wuL0Z zzfbG>h0$)6g=6fs0{QMeonuw*ZLL^=&fqkms|vP)krlcJj*}9)nRkb zPiTFk74DqaRvji4lclC5u89?fa9tKznkpg}i4p7y2(d&i#AIj=?c;J#8_Z1miBPm# zXDe=#7EX8e!}8{)@%6?qyIbP0V40D4&e(5LZ((^?XmJvm3qA3Z3;C(SwnGx`N$YY? z7sIK)_6;>Aw&Ida!nI}MaNATd%C$=n#huHxwJRnpEfi6jjg_3$WPC6ag3Ky--5h6@ z`yKcX?aHFO8rvJu$4p%5HUqCa|Q{Hf|Uo2qaLv%nbi=Y!ON- z5%Z}W{oeRK7^UrR1SgP%O_RU~d(9OO)8cE2wRl=Dq&6@^zdgEH^+&f(#m#R;i;Uo6 zMP9Wv0*Wp^yOE}}t6(9Z>@~v%>mFWhkTDmZ0VXCI<;bWcrOhcZ*Ny!N=_}oV59h(i zHIj<6w9Gs&DVPw}C0%QqJV#sq0(r=+z^7`K`ZX)8ms=E=NHc0B&x_oosc@SE4TVV9 zBdws4Er6Ep?@&;i#nm1>7H&y>Kjk#3`fQW}CD6WfE)=5zz=Uv~{g!-#5%#qeY+o25 z2(b}*P!Hf3xB+Z@F*>}DMZlDY?#lZj_`-S{s8aVY(tWj6@ZM>~&Oqy3hqCG6g6g0j zA`vbnwl%mFF695_4tSyT7<)|(b+1;rh0TLNjG(X}g9oi%*1t3rij^*~s6h~@H@M=1 zPYHtE{GA3cL2`$y^{ZPAT~)0qu#LZ_VUx?;?znf zt+^sBo3%Es6GnXkG1w1C*&!pVtjf#-E;&8eScw5sZys=8-3p=jft%dt8t_ahjC@Pa zxO;}iZA5a?b+$B3Gn!M#y0piX?R4U|YOjCCtbZ!Zws}4`$iA=)pJWN8Jp_h4XIlxS zJ*qV6ctePZT(6kLIepaJ30^dp;K9y;xa0=%bpA5#{@A1O=Pp$FXv6C_m<^g+S*gLV zr+xP@?>T-{51tZSi2GsdS-Wl^JjCq{ZqI=C`X|p)1Wy%Faku}o$)pJnnCUvW1AwPJ zDyZacE6t*&hwIm^mqV>b z0dSI6NX#c4w>hfZ-lD;i6E!NAe76pUMQ;`JUgcv`lFP*`Vul#@#~iAMH7dD~Dj)cYxKqihfLLvi!3Z%=RNiyu%l=Z4^LuGy z&}b%}kp|0hU3`9ZZRcamIzVJsC=ce^EB*}7D9|9OG2RM+m+3D?ZPX};TKQ7iIv8fZ zSj7gJm@vRV?Xg__rRKP82B0K@1Am?z7ExwAOuIlrdkqK)F#;H&M2=83p2BSaGo-1V zNPmHlm>3x9WQ43XwzXQF*{1kez$~LCIm4VHgM7G$9*e2)Jgq0}e7(}9> z^J$Gg?{@;aq<_~0@V|KvM_CZPTzwu|MX)zjMnxbVr`8x|1V#1?$RvbFdQ9qy5iu&0MDEoEmVI{qNX zl>y2vyowkE>I4MWV21eWpcN7Ml1LcI2>YVL)`^7*q7i%1>^-OC zTmguE7F_dApUWkOzDN#PRpPwVFzs1b#xl6a-%x=W+Oi%R^B2@c4Ye^IUd_u0G9v%! zLVSKr36LXW!CrG)CCD5AcU+8ErbZuNp62l~yRdLy-j;&^rpKb}0Et-9l5k$w1|90>Y&8AzFUt(r zsASfb9Q0WU`ZR!D8K}wYMQpZ!&rIw&l50#Gc26?ShlvvR!LlT zeH($(e)IYO90SK=J*O zA|0gd6bKee@K?mxzh1}O)%vC@6TZ&kHty}PTIqXP;JdGma9f0b#>2k$kA{J`|Gi`SOiN}99^nGrh;RQnk9$d!3Ab?p#-&IJIaZK;>kA;u-5PdU~3)$k|o!mC%5TPL;wAW z^cA&3gDQ}sYb3bGOhUUHv)8gl|7lACPq6sud<(4c$HjAk??=)9!WG+%i1QAPPtoYb zFY?lpPK(DdjGb!^$Z=B^RHw%2c`YNyf=Uhx-79y&xMU>*m}i>P2z!@x5cpanKURwP zdITNG!|{_(?KQH>%*Z9$9Je4``sCJR3pP5x@|0vNQ&!Qz!}Vxdvq&hhX>FDZtjvtn z%5VH4Mt>T}$!mu$dw&L8nn*Chey<{&*5I`o@nw>ANhWhDTlfB;GuFT8slAwQ z{nWU$rorQFUwcQL?>dyb7V18Swwj+0(lx*jr|L;6%zq53rE~rCD^e~ zkZUmxCyW&_F^%=O^K#r-9!B*m`hM57Z6wq-%~hTlRl9xlzi&_qAhTAASqg!AlITA` z;8JFmbA_B|76!yR!DWP%rx<@;hyMBWgqH-l9Rdr@LB_N*qTA7+`rRq6_;!}S=_Y*Z zKzFLqr8?ZN>RZ1;)@~WV_Gn;Ej$8MIIr{u#HEA3pe#YKfz%DisuA5eOnNG6{Zw()G1*RUdOqaqq=4V)Ven%FjabQ^Tr1_{ac{ z&F?cV#)(YRYvI;n1Ve3Nt%L^%Ay)h8P9BY3Cl8H(fIExnR%Pin5M+`95iCORUnBfI zsM4$Ofc;2sh+i}$Cs?%tUoSck;&H)O4EGkn#XBxdK7uFd5yl2+%2z~LFFZra^RPgrYyH>K?M z5+U3+!ktA3*fT_|swq;7uoc60RwIIy@FZ_!5TrP6B+%zRT=&>#|4Vq38lJ92rubq1 zi$bS!gZC3q$mwDVEF>M7z{MT( zzdJ9{tIm;4%0zZZc;T5k2Yl}1MFjdz@_zH*y5IqdpJ(+^h|XB!iVU0$A_ zh3;L&xW-}9YY};xN{+9_Pq>{&r^YKGxGNObKBh~zo^&w}Ps(wGZ6m@%qgI+iKu47L zRFCVh8*O|i9W;CvvI#u$z57fKub>(mCP3h=X)TFltqmv=-18H zq!dBDPr%)BN-&qbPOoW<6na`~X=Y03U&~OJTBP-H#IvhYNDX|q_r1G^k^6y?w4h>( zONe+CnBic6<*q?o_^xZnSO|6KFR1kySQ-SmSv=1)9NMZ}`1vj}0z@9w^{0?&;m6VS zB1O~j`BgHgmlPQTlG3V4TX2C?-)N=K=fpxY?w2_2M5_Np{n+fux#<-Sp_qAyyCq>! zM6;qE%LA~SN|HnD;Tdf$D`vzk>eFcZp0a3qK+!z6ePl&`_l~pug&AdLczBa<(!$`} zqubSH46BR}vD3rhRUU2+2|xZ7akRy^@MWiGOT;~`L+ELk*(Ox!Bd^TT!p)odqY`+) zCq&(&Tu!wI?@_<_b8g)6p9`N12&acHRkbId%-fZan};(^c@&oO?RLS*uGmL|*#2B? zTVS;9gMJvi#=^9D0PWWpxcScpq%kPluk{ zW*wMrdeO7toV;mumO#Zn(h*j{IQ5|^q|5QKTiOll%ROEB$MG2DhF;VE{*{&Q>+3sz z`XAg2x716V>m#`!)tY)NXyUnyXmS3)z~@zse4pzpUFMstZNpC{J+DBY$d3n*CwD!! z?XpS_`aDV(z^WfK>?#1d+ERtKU26pcZxIct&ol3s8?BjjQ14dl$f-}wExOYH-g2+E zPAdfDZ3U`YcBAocwVf>ogcGH0=>g) zj`D65cAzZ^K*W_*G7bha2gYGlRUcOLP`u-&24SUZ9bZyrQ{R+XX_GrRT)}GP+Euvi zQ}EKAIx%*gG4pC^jbl$q2j=%LkdbotK0OR?ujuZ;2mZ1QtBU8O?ZBFG7ro%gp50J< z;5)ECixc42A*Fehi{Z&XLU>YIy;@vm-=Z9@WS4MZ)pk94M1`HAWFE^>2<)~82;*!k z2#ea`Ede#l&`SGma6!Gc8*{$Kx#?0Cwx_HHhP4uFd5$S(1opMu=XqxcKG9r_yYJL) z4KbLRKH}>AoIB~-wNSW?IC5$=zT0EAytlhcYaZQwwLiMxsMjU)@6M;cN83~Kdk>(F z`3hAr-Ha9vy!@9HncX3dug)}j>HZ^q#^;NSBxl9s>*zB?zfX(f#Bj=AIFesoB2ax$ z#mk|LwHPWkx`3%d+<=hJEA!PyW)lOt^FoB+cw0hZ33Y0m-xlht|9f5Ul8H2;n6c?` zHx@&VZ4yMu>6?Nox=uJm3Nn$jNUrTqV0>6v#}9&0mCRa~eG3%@vsr7~6bgrkX|M!p z*t|}R8Mu|8%`}u7CAe^CEMAS!q#-)$iS*yk$`jjkK zXUulcl}veH&m68$_bTp_s>}`CuuUC2W2MVmzv}iv<$)y_UC%z;)}Ufvd>RF5_2 zR-lz3UnN!@d2+YGOgy~N&LU~JqoJp$Dghk*skEqyEt;BL*{|p?n&gbXlCh>YvKU>1 z4-Xb5;4);33NJ#qQ|F>}v07TOv#<2%;+jYsVH`SNYfi0E0L}|cc(eh63R1#)p132#q(*VcCOb1L zi5o&G74@3w^yp9ygFLU4TQB#Z*N$uGZ$SE$bDZ?JNkUYqknS^3`~kr*EB-UwL!==| zIb)nfgWM$ppq6Muj0lxI{2V>RIxmE%P|Mgf0A<4s!3fkV0BI2d_dmlkI;BKffyujA zvE*A1#KZNlhE&J@J@A?e;WBdp!ykc74|36LP9`g?JA^$TRCpV>_#zRER<4J;WkYv* zd~x;pxRMb>jYYVfR^ zXZgR+nXw6^?M=f}#l&x$TJ*j_gn3V?X*g4lD^qV8rkyQ;zp;bUZU`?CLZV`}J{k5O z#@2X*PDh(N9J>DVkS(F8Xmf&2QSZx<7@bu~%zxQert`ETOS`u)Q2_4Hs68FvaFRl` zfN=o?5y!o2wWObL6acs)^}X~ZMcIlQe~z?CDKQ-uxBe-&MgFq~V#mMSmwb2_w-L;S zuH{vR%R*=#6zIyQ{p+aZLO4k+bt&go@QY!_0>3dAA1Iz%2cz!u%W`A^M9a+3s!lP| zw{N;~A7r=jMW&qH17?V4=ID)Lc_5!xk*XY~uQV%u%9yT5D;_rg&V3XV{Ta8RHAFI7 zf^>RmL}t&ISO-3dXnx5%%}VSr|GC8YaJLShP%U6wX)1H<^vdGb53l!`jBx6;W%^vy zM(dj7-gV2?(duF9HLh*omg&lPOr*j+Kl}`HoWz>WVm#&7TW8k(ZhZaDTHQFRI`=Yh z{|8m{$~2V7)x%R(;^OeMxCvs93dv?_&?%~6BXDEWdJ3}f^b8NjN|D;lDB+iT#H5gD zC`Tv81ub%Y0j|_+NtnK&n|FDJUuH9uj7<|rs0US28xI%kI>l@HKJa_xgT6UznMdhU zB6dI1<3e9UO%^ENl>!eNXn>ixM_>=Ys;L=@F#4cDzMhkb-#h3M2Og8D1N5`F_j;B)y@@~pKgFLPBANjxEJxoItjQ% z4tvXBVPBj3&`KaOS4uK*uwu*C8dd4ny_rN2cc-2DO@L9Wnz*L6Qn?lYlK@XG3iw$$eoP&xK|@8##1us4&6#!<-J zmGaDyHxIu&Qy$*exA%I+yBY~yNqODh@xCDADi+&DrP;S`xEj8bi*;%2$0?4E0) zr!RQZ7Bl$A(OZ}+fs>$)@)Bn9vmztMIDg`hwky~Ba5dXQJ5D6+D03pvb*`H471}wEQ8wDKn^}rod&N3u+59S0^ zvhDx69l^!T5=rJep!V)rup1pTOKr4g&HNf!T#4&4P*@8iGe_N;2HcupPTX^w4l$yd zWbOt@fD#@z4lCItaaSOIu!p&J1MCu5DhybGFN-U{$gdtCdq`eihf(}au*EPdL!%?O z{~?Bfq^X?;%zaR@3IewiNTeD`%Ym_=MM=Y3?*LAdJKqT`3^-&5)|tU-^?-LdEC$|K zD}p(TC2k6du@=JOLz^B+AFNsL$Co;Q{Q;4PalR~2iyJa+^~jSpn^til7Rndbo4M69 z&u~{v+b-ov2l4X1hexQR{DiV_ebCiSa7G@? zL<^lOf)3_EufkR)PDTd5&jM}N(+(JDa*NlOhw!*XVt@)3kO%x70?<`9msx;ui8!#S zY#&DMwkS&w+9oOxf+j?CcTD2o-6$O@S%t#gv=7V!r@V;7MP(ufy|XEX5^%-ijBJ;{ zcITj+eOaEOkRQ;V*#|);h|vcI!3lfW5kIu+wz7TYNI`d*_W`tn#-3WFz-v&@;<5r2 z^^mwM9U`}{yvNtorK!qXcB3~8{_4G>EKPI_2layP;HCP>(Z;sBbfqbJMQUU_?T}|0 zSn#F>hmp+x0+2hE@<4yN@n^C_U)`uz9c@QhN(nNx8$19Q`5q{>%bZZtT_E< zC6Y$o*&xEG^|*M;9F+eZY9W?vvmXV~Ri{Xal^?D!0cN3xMUPATxUi67*xo95v<}pL z;em^?%15xpc}YxyiD?7WEeu91hdKC(N#)llVsEkvuyCxK6QH zr*+odAk@s_E}FV&BCW3Idi#};JrvlTiK_eW9#>7_?i96NYnd5cbiKxj`t=a*S7l^} zNA*)r{D>hrTL6AnpP#5%F{%@3Ur7l9U?=a3R=3o8)tcfb*E!(o?fmVm zjOtL$NJ6dLrCDAemh)PWDF?1UH#HDG3WeN_YfYzvlZPNeu@y&#;F-yxUq4 z=(|ckR&!F~mRD}y4LJD09JEqiBr-x-njk3KEJC;lWZOlkFcB)kaMBlpSUV_9dF-aa!Nr-iNXZ2Lvn(;10aCzeJ9Beqdw zfqa))+e^NthzPY+1XmUq38(Yn1AO3)%5q);1ga!nkuom|d`)D{{v2QhbW2V6;5b(5 z%ul{FgE(j|_a2vencBR36?~-_Vxp98YemE}O@o#4WK(%Sw=Ak0rayy7R7+~Qbot7= zH9yLux>1%+@o?ua(92_W&r)^rRqEQLJo$fD zkPfxzb#t7E4l#8+P3d-7p8`i~~W0^UmP+1;h09V+m(mQl^4BnKo{J1U{1_(Sr%> zWvCsdqPRs8Feq{Jfc-nVB;k2YS{lS zI@ue-><5A1Fb3$L`!|{YK0X^zzQNhR(Qzg2m@)6=BxfOPYE1Gc@tV8GrSDxZx=?br z8}3Kx4OgO)IN8CX((RF-7JQIl+MW;$-o`zp>t^OU%siKg;CUdN|g0!ataEy?GqB zPBGIzQF*f#ZYX_qBxvGVWt-nVuucr;nSvOcT<@#42nRop`Y&#eL z=x=#k3G#Y|ZH7V~Ymg=8$yY5(OkCs{`qF(Rmo?e7Va*Dpk$NPnJwEVJ zUZ6s|+byo}^)gEyH94=$KVd_vdYx5m2QTyZV&@cA@*zS5?)`(Uv^kBX~E` z1%C3nQxqeow_GPVvE4ZM6<60PH^%_r!-9pm-TGD%rMB*QD>=CWsqtk0;Z{x<7ZZ+( z@vk4+@d0Ngl*5O*m7I@KoZBoz^!)waQPqDs8aDCBKXyf3d zb&14<4|fy8_S=S8Gy(ofxVI7kw7}f?FuGo{ea}Dsx-l~f)?yqQB>sn{l`su}h3cQw zFbQc)q;YdsCA937lj zUTM21T5n3J(*Gh2%R5CjsS5e~BsWreQtiIL0nW$5KNMA8`tT^o4{2&D3sOqOo1wP9 z2mi$d*0z}+cne**ch!&UzrDT1S#gfjJy_KY+R~`J58y*lp0pAv-pzD3tG30s7rlCkM2Wz59 zJ1l&b6^>=xN-bQ_*Rm=|oT%IxuHALaF~{gSgMM0i+jzy=Pm+zT+g8?Mz3#||CQo+E zF~7jGGt1DNyO!?hb=v07uPUkRL}hj}92~upSC@1zd%lw5zL`60KV4)mtsw^MhI*CfH0IZ=jye*y%C?;LFxysQYFB<7F#W=^ zm0)(4Gi3<7G8g?YQyc`3Y7V6i@t9?J6mr!%tu@*WP#pVW_c#+N{(zA zg0%yCXN-6BLho$esD^~nQwjj2l}9hX@H(egEWtR&-Ym7;vZ#q-OAB}i>%F7XW8esE zVATA;3~xPMknFjBTX6;y>9jXCWO7ZHvg4lfK5>HFDY;rQ>C&cHea!n@eNK#H1E{Y$ zwZ(pkW|w$9u1M=tf5EZ0sDq-sKh`H2g(ltVI2!yhc}=7DU&6oc2j#w%H3VMNQU411 zSR4}Lq8r_I&+VEJ9^<{DlHy&vMV%}3L|yT%s#U?;uG~xeRNO%IxmNI3zuU9H(RlBR ztyytX&xc_6JL`btXrE^yT8z_&QOk`^k+kVK;>Ln$B=!5g%y5$!N|Bzrf8lZ|DIocB zn31gqG@Ksh+Abl674Xnbv96jhYMSfxnE8IKr=T>9<74~o;*a7{dKkYb>_GI4QbO|A zD&W)&i?5}|QCuF9v~EcsMh=E}j*)pmzM*9ACCYE))%txSj3QPhg7MnTHWU>#^?8J~ zzPxOdbl=|=Y7`sBLnh|Ii1TaUW;884DN;hbAx3fxN}Nwq9K&b9(hRbQsXq_1`oNUJ zGBtQ#rG$k@DhnD~`57PVC!w*%c{{5+sNH@@E2?K@(2z5wTv+F%&|tP(?l6DY6ye?P z$tqO9D_)FGuW8N1hSlRJ04D7CcW!wo|8i_8C&>bn-~hgC$w!t$ZC99>HUDOk3-u$- zIz4Kg(y3;bV&rsgm!o$*H0cm^qjFN@7Ir==u6$Gdsc*V-!FH1g+D+M~XVm9HVTp?E zH;g+bRTn(!HAFx^V{1V#br|FsNmMeYXn%~aqCMf*^CXK;;$EkC&$J{NyyfqiE^p2} zDbqu0Hp(nr?Ro)W$K>J1A?PqcdpZ*Xr3jf7?pj;&0Wcysm2+=zyAoQVzP&@m7Do)0dV^X^~PQ5QlIBVl@(c|)J+%AKF>4p1)^=%;Tn`6 zUh{CDPHHvKbO#Ku)*mf6;}mrs#reR)9-Q4q`v^fs_oY1CY#28AWpOOEW_82y_#6u` z(x1?idGwbbS&moy5sZuEq3Ia&`OF~S%{3^VYGg0A-AwpY6+2zf;TC zkGA91@nPFe^Whc+Z`{2gBzmU-;W)#@ZPL9Y6+qDN9F?k}m z#c1Zu`7s{35Q6yFJBKxWCoy7W{{N0A>WMSb-1iV9CtHo$sSeHi!L$sVf1hmdYhZu! zz=X~6jJl`@JA_d=+YM{f=pnOp{oLu$_PILei^`@ynH?h9Q7vbN+4dkJ%#TmR6q+VA z6@D071w@gy?@nr&&FPVMvcUo|6^uHl!C5VsSaqu_0tD^2desPZQCU&E ziAO%T7_zlTRhn)yhaEgLVzrztV)zY~4q@aX83* zAo7AMmxoI&8Me|bLa4b4IGtu&?6VkSq`M$RO@G5Q*$>{ zC;)n`IONU3qx3Oy87i`yxA~e#9)QWk)?XZ9E-F+1ScG7ryF;v=D1seYCAht&V{3Xm zaklx}%!?wzR{E&Z&b)DIYzaqZDF_O8+Kknfw)rl$Tzo@VW!)cd;&&2+DXsK?)uP+8J=@a^%B+s&#~W@odIC{Qsgk9Uxk4?op+>GxOb95W66gXAAX1u)_d&hh z$~SW&^0~YN831Dmg+3bcO9Ak9FR(r8m64814B|0#T%8l|F?2oAQuo7@m;`!?94IX) z^?5LQnG5WhC`0Piv?=v|w@B2u8tsu1POfz`KEj@_f<_5 zt|yAch@cZCt#`6zCjoXtgHRcOk(h`>x zq%qflJQ+Y}huPOxS~7>32#FJ|%5EHrW8FPoj5P(eng&;tMb+9sU z&rYRT)uzF<@bG>RY_Hm?Un56`*iGqI2y3Oi6Xy0prFTONRw3R~2_ekz=*8Q7Uz%4Z z-#1@GS}sC!ev7d#=G|6?SmzD@s=iw#)LFUenBC_rv{0HJFyR1KnunS7!%VMs`D-41 zlxJBDvlh%G)ek$VLPx4YM#lN<k31W`NnNbsz5)$B;1I%H#U_;Zysg=-g?ER7J1xXz?U{cltd z8x}Yik~TFT-#-kBIO|(su9l-NJyX+NF}V}J5oAx%C2DjSgftEzec<9|tRjwx zktOQ+$Kp<08SaA^mn=p^a>+{Z)6EM-pSeAML&Dr6iNz7b0u4zoModA-*>}Dp%1Bzs z7NPoa9F$nxbMPyd0fCagi)nn#-pNNspc~p>lT;IuM+mX>`knc`(pt3rM;z=}6!u$N z?f*5u7MR2A6uJfkj9JoDMwe7r!B!X8tTw!-7h#*kSoO=--|i(?jGi8y1U7#$HD8fW zoK$aZ(Hg-AD(W(EE9^H-33c!+T(@iPNd(mf^M*$M;Nbu8zl+o6NL=R(uANnRLG;F{ zi+6ocj9;vjd%}u#N8?)^@P3o-7lme>dYt7ZJY_pGuj#@dfC=K_ITAYs?~FI}EVLc3 z516o_Psn4b(-~H`N6et$60{>?zwqv z=ui2P^qEJTsg@UyA64ank2<_n?-Z**b8Vx9HN#`q53^$7KC#1`A&HkF#4cde#^I!i zs3C1InKci!{@a%A;beE4aOb;i7|0o`X;$8#)Iwp%k zBl1GEhFBkEh>PgO(}K`IQJ8l&{H4bW0U_CS>E89DW8BeD%Z2+?{`L*KL&#hxm0xzy z$o`EOY9xd*4r_>F4MWT`DFI#^KAIHsj06B(45iA%o9j1GX_#LM)>A!9808pi?W%?6 zF^}WO?JpyO?rV&LpkJ9#@EL}DCI>NMlU8e^Sv7ruOCdl}M4*I+E)KgxakVBbuMd6l z8zozeG^k1A8-Dl^LgoMjuv0y`6FrJTO>x8T-&rYvkaZA}07B$y2tg+Z*&4FegQVm3 zuBeSlAe1m3gQY>IxW%&b7Hr!+)EC#@Sw74Y7&c;|-4?;4>yg&&!j)=pU?|1ZkyX1en_`pzvFrSu1eEi)$Wv(T)=G#6f}^ z;$3FN`25a$a4P0H<`#yT zj6CQrCe-pSFYpeJgm_>?(U?HIOlMYsgs+h)uvqMfwz4A7B28T93Yf>JB50O}NBeFl z%}*0Bxp7A}kHO@{OVkXpm;hNQnMHvhuszs^!ql85J)pO4VPxN-1VPWQEj#S7Wp9pp zWy9nhl0Zy$eNTwwBEs%mNP#Rg^dQF}lw$Q^kC%iDZn*h2)Pm|@`vY=s<6-^@<0-Xe zG0Oxbw@hE5`M0Tj&1)^BGatH~4YlM$F-u~G2T%AGV%)2NoQF^#l*$EfsxAycKVR1R7V)z>=HR?VskvB$qF(Jc#67GMcm z!5ijy-I^gDHqW1u--7Gp{xG0k(>-qI3L}_)L~+6A9Fib&u~J_f6!E>Hj2GSj>_v^0 z%*B2Ipe2Ve^#CIeFedD_hID%sowF|A^zu26o(K5(=lt|UW2ujMedH61Fzh6fIH_K& z=_as?Q*5mB*uu~^!`jldY}ZH$OEJv0TnsINP6SyvBD^=<{pz+vSkH|2Q4BMuq*arl z@gi@Zr7bRD^9c%z7*DxnhmYQ7?pS&WyZ(}!P1De6 zn#H)=MxZp}cYz>QPtzlf#JulR9_ld9D0chzgFKU!e&aVyfBsi`7vzA6On_Kz-JXcC zX6|hNcAFZ@M{eX%g)a~6fbQh*7{$DucL$6Lpx8Yc^mYv`2w;>w3155p>n=p7i(AVoqF8*7ZUCS>N1zqcPFs(+J|ykE{%FU4NIw&`<;UQGRT`#^0_*YK@SV-(A?<7BRaeGDQ=Z_4DPI6_buv_IM>1eq*G6`tKV1% z&9Qmr_|~(Savpy94YLz!iCJyMA8~@nP9;ed;_+?a$4+?{>j@oK(<|E z^KE#kt=*GfdD1)n$wG73ANQ5Ksj%~KLmBmY83jc@`3irJ8$l#hk?KidyyvD9+j)o;d4XH*~mJ zh9)$zBkP3;ePLKKJGz{bP{;bG)i1J#e?CDver(peFL^EjF-fq9?i1V@MphY}%3vnU zXTj@SB;$59t`Z%z(l%Gk)7a$dcoK4uf>&y`TAPW&x;8<}32e1(x%pZ{CfXrSZ#!(6 zryQ1=7Z^k%W*owNrP;y-X_@&EtsT~E=D4QRLOBVkisFD(DEUvoRHs|AmRHF>=vRdQ z(qF#qDEP_4AE|JfFSo}z&x__`KK6hHoKr*#tb%pQz{3Z1Fr6^o2+hb*e?;vW3hz54 zdxQUXij2?2ne{8UIHyqxkHT9}gd2aAPKP3G!ybl^JV&W91pfs}2<_t|O33#m_+##O z{#SBxLhwdPIQ=&-3Xf2tz2clD^xY~IMRt^ZPJ8@K)zDLv|?`1ts#&&xdNQJ;sR zl;=B6{}}pg^}`bTG=G>UvX)yzJR|pbD6+5J(aK{P`KpNa1RokNyX9b1mu;e91TEk5oDSNSZcD3)~~2V z*|qgwg$Po>4emls)1M{SW9(~`Y(>C)ch*tPVxHYGZ)$TkG3b*AJ2&`-K~m$~D<~(O z{?wgS<+iUTGWK+xHj5N=&`!JPXUOZR<2hgMq_!Kn7n+k|Jzv)cebCFOfghSWX#?-H zXXzc0(6MzfTQ0w%t@7J}q$frLUDTbEyVtMYFFb#HO+h>$SLLt1QtDn1aQ=2nitw$B z`atE(W@O}Q1L zkkh`wtSWOQ7luzAgt3{QG43J&lL1~|9RbG&qB-=pN44m+1~@rV42CMjrE3jP{If+# z;If$F6$ZDNQioX?xG2st6w3m~ItGxi%CpBNN5XszVrsrx9>C%eQ^mu}1qQgqaykD0 zl7{I+T6xeyJ86&e?_cIcGD2iKDNz??-Cc@s&MH0&aRjC6w3j>=$Xia*lW8NGl zb(4hUe7V#uQbUYhrc%y&$g4|SQUGS7#iIDQ1 z4FCpVOK6W;C&ZZn9x{(sBVu-M45g7$y266L?Gad2rXE?Ao=$#^=1l-;lyq=IFefJc zTMwZ?6xz>xQ@Zuv=(LGg>A;lSEyvqFbZp)sKH*^u z&W2;9MgTv|yLJq0mhG8N|)g0{1OONtEK7zcDn`9YLR8~F;kP_42_67y!_}W*bB2(GJ z9tfJH*5Kk6yR8@dst+P`nD1kT z{~nP^Om12*d)z(3T8S@WL@l(tmnaZmS50MN{Z-mt;5Qf5TutiKh}D4~$9#lfHdH-q z{z6q2h#K*~_T;MZlkzba9S^@RdmH0JlYDc!QB_)T2VH7LaEh;`?_832d{o@|qA-Bb z*~}gL_^>a^6Ua}kXoRf@px@9Uf2U-Y>{}yo=$}UhM|PYWa2jUxCRtb*)F|&oo*_>N zf1lwo$+PSQ&uXMPWc8rfxJ3J_b+!hby3C{M zR<=pMl>x5U#8O!J05&q*m(ymyz*~xdtwxX_{)XDsfrSDih9vkm& z_*6MPNJ#nF;kOH`GY)PV&xfA9w-Hs66Od-e*9_2%$_wMqY<}HmYO(B1+!bWsRtch- zEoPWZmYvaD)`<128EVbcaU_wV!Hzz@u>F7e8!Jr1daKRKJp{4a-14B<&fivGrb<0& zi%4c+4lVO-3ZS1YiHb1u#I#_BX~9ROQP(JIDnUdx-edo>tj1dJWF9HP(>J{CHYRxH z18JuqiZP%XV0~9mi`65vH`=QK*_ug`qrG*_Zfs%3Rcd#41Ut62Qj{(=-mQVVaf&LV zs!7h*6r=1&&CTek4ry zY-PddnZ3RbdV`}g@k6ci&U-&2oKyfYT}QH5F`zf!T8n3?^%&j6gB`#bnFULO(p}da z=ush-iZvMFa-5M@CORr?9)G|vJoT%u#rwm++Gv9ZwZt&8@-L?y_tgJ0wN5Erb~{E& zD!)m6qd+1L^00daFxC(r8euvEX%dga53M2QhNX1}jy`I$1JmM;u zQ-$G`O#A`_+%d+Ug^=H4I zU8(iR^j0MPv{Um-CY!=DYMek$zRnOGHFC8zdO7w~C#F>MTDmjMhsTKH?2`|LWE6^s zyG@7}>Y~Mz=zFJ^xlhb*6TO_CF3O2vN~DswiG=V+KaBxlW)fkn~UoFN5%o+J< z@Fn^98|t!O@YoyMP<5a?#>D25f1E_)7nz@WxOmVw?Gd6&^?^ah0u2z{Af%F+Nxi-E zVxf9nQ1920-Xlkt=sIRxJ{Rdaimv-p7FLcH3Dq(PdTln!T?y+`qqg*-)-Is!g!mev zJXee=R-;%-v`CNQtb$jA2XFlPpxWUO7nd`XE5RhnMV0|O9xJd|*0h3Pn`)VI8$ z^)r2AZ{x?`M-MRZbxe3a7k>(ZYgDrOA^6b+Qm;PVuEXR^L^X+z?^K@aS8ng@od{AQ zJGjKrP||`rp5$~^HC5}?g$PI1W9BFnp7~(A^>Vf4AkQM!_UZRAs)ASNSCO(dD*3yf zB5O}-NlN5&fvQ0=;$u>|q9LMd!`k!DlEbGg@1ZJQlZz)dp6VhT3~$moHBW^%S+d*A z9f+3Egoj+C$q6&lOzzl@R#^MiRw(g5EKxu4bX5N^T1*ayY?>94tFd8Yy_B*F%5z92 zU)gRxNu&V?ff`*Ly-BYp8JP4NYR|9w9UVg3875_=70NKlia5_?>%};OkkZ2?qB4n% zGq^<|Q7oj7b=rMm8bxSE`Fkk8j79+=XJexaJ}Ef^D6g>*J5V~zUGg_@iSLAzi1xjw zw$U6k=BK930$s?r&lZ7#-oPZa6|fNUsH}FUXuz0nYZ9p;@iSMS(l9_)@HwCg3EgnL z-Mn98#RzrJoul>5(Ozn-RCX5cvi=APb<5K*R2onNrUhj(tmziXXBJ6ctUrL9uhfcb zYmvJf%N6?dfVSu`R`MZCp~NVd*Mr396390d#J&oYA-U@!sQ4jf{H>k6O|SZMrRYvQ zwiRSH)B(v5aI-gcq-+J9@jheoUDUws@^U>cU%X!xLQP3rHKa$A^3dh_ultx@ds%4> zOiZpgyhM%t^d4>437$Q4kPvfU|1XHH&8T6|ccZ*|Nr_7A!YEv%j`*?@Mao7U;f|*K ziJRRU8)uB`)MM&|4j1gP{Cxc34sZrcey5IuUo~;e-0)tFA5xVNFeYP37~8;FTW#|Bd#s7NnMe1-REeC| ze6N)*imQgEoZH&jeh+1vOwBxz|CaxqsZ=V+V9fnrZquy(^e=I3r4zho%FI}b5sKHA zcQHdYfq4A-B+7Yf>_2rn$f-AP2}kZhqq&F_NEjU&b}V&jGB%^G3%eSH1bf(8=geX? zL=;e+-(z}GxTA5VDn6eG)A+p<&-ZfI|J)N798Gow@_uK6+EQc!F8_%#>O?=jUwQ6j zxz{xz$z2n5+U_rr8vRQ)X+~)L7sPy9KJnmXQm+GPKp+2>YdkMBkKnn%@C<``cL0>} zrrqkx{O;&;PLXre2smRkeve=Kg}nAlu4R|JUeUMk687A;%AIr0ZKs3ASd5ve+N>W$}BeUu*rJNeJQF^oqKtkp91_Z(gzjm;ThSe z)^q(WF7N^-yikoPoyHI7ai2=5rEolJ4l6!GT4ZAT^dLZs%@<-P;a5!`U`zG5$_CWo zKR1uuvAd=;F3;ZrziMoNFO9FqHSI-xTSVJoKGo5Kl7%B{mEidr35z7se89jvYBhwj zT8JqXvz*#}!HJslFnqTd9l{8&SN}7)vtrXKQa(PA_CWS&m2aI84^q(i)mRjC*9|c> zvYmR+nal#_BQ*GT;-xR@kvb-B@&dTrTo(IROe18atC7nt|J$G1QS)FQr#Kf2Q{?|T zx_yzE11MwLg}6&B(o*?k2h$JBZeT67p4~x-&7>@+Bs~AH_3spl0!ZMFQJ%yGtz!po zhzc~5W30Y)jP1G4+EfQNY&*Y3wP!vU%(U$1nRnV+z6&>RZ4Z1qmF|50coU_rmB06q zm^>I|s@_35o--qRo3Q)h<`#XkATGIDbCO$6oP?ahoE1urZsY5T{bK@H4~0Ir$JZ&r zI-C^khI^^TAI0Xr%OiHHC%>7VYf%&RkV@we5?@2#P0s7!BD3Eazg5-Cyuh)nt`?)HOOq)Ws2NPrOw zWt@^0T^cSfC%vDXbbUAMI$E^4n6c{J{6dJde`_Z7jolwY-X$O2ey`Ak9dk5QJH4RY ziZo)ymF~%Fx9Xcq3Z*l$%NYDQs~5J`)RCf~SM+#WyFCy@)A~8}=*LPub~OZ1&qXJ5 zk-M2CL(5U_d^A_wbWI)1--)kQ;)mBC{J_Kxafzhz^aC>V5&!kAT->}me6kF^n+ZR{ z1pm~3pXK`aLa7G_T5gG{_1txven21QmMQfZP~P6%i8ho~?bjbR+yUt*3@ruY{vqSN zl4B9MP6a_VG96b6n51GzXvjdEYUf&UCvg7(>_B9hh>3RSM0Hl7Ys5aMgp~ev(!t)s z@7t&Y%((DBqeF~w17dX3W3=9#)PfjS4S+9j%J&~^CJ z?Rp}cz#3^bcCyn3TJ&s0VT4HsHVZ{6M!kH;aTHVk zA4O*#mQ>omahAg>n}DVSh>E$ShNk6OfT(C@xD}cenwhCJXj)cIFXw=YmW69-R%m8q z%b+=pHO&mD6_qtqR#rBHWsNn}&@%JdOt14hzw6>3{^amn=Xsv{d*7e?*nOVkk9nKN zO$)YfB@2*@<y5_K8fm&Khh13?4d; zGni0|_4sE9ht8AuGb2k6@c#S8wqW;#!XGyhFZ>hx`b(o)aVc@`_q~rbxPi*B%L;lc zdVwdBbcwh8*>7Ju^@n#UD39lwehSOnqQGKFIQ(h{m&;hnG_JLX(rlyXce(Ls?Q5P6cNGZ-!-?H` z)IV}uoAB`4dAw&cr^-$$Izx9pBPC9c-F25hD2!iTg6^0YoOYq-&FJ*v8d5nk)6jEF z3sG7m)jdWtC#H937we8KcMeZ2>ka4(k!UA}az=XV zQ);)N!kZS{YCkX8`ipUD#9McG$=<*04Uenq=l5^zNqjN#a^>WOVr$!@wY~e>>0Xjf zVMA~1zL1>Zo&+>d(y^qnGSNT8$9g10{cK3d@S<=+i^}-z-M!aYeya|Y?YgVE&RDqR z%$bxfD8ZaqAmrjwJXtP1180os>vdczt6<+5GPh zsmW2_$j_6P{C6~_ZlO_e+#N6V0Ren06O=fX&FDv+L)rbKjQ9t3GtPBJh$e^6n9R84 zpM)`OA^e0FjEC-B5Iugvc;&E6*$*%i%Nj7uvd8ujYSM%rjbFv@lF&oOWHIFQ+!PFb zN7?)J^VX-|xkBdg!BIwHlsSeb8q?A;4RooBI;ndSrR-q~$$CX;|taD1g zTQX^~iuy5(m27s6^BqxFEIWmdyuxbpxHIaQ)Xa_-Wmt1!$-GHqhi|K#uJtPI9j0_NjEdB*ck+CsAwQ-(3-RB^{r@C}*Z7)VTamHtrs~(z43K5pqnZ;VoVna<||k zCHl+k#-Eo=7c9Qzp5G9=dVZGMhtObc=XjG_#XmXBUou~;OTD$`^Ay(YKW}ovSjBHP zLHLDQTD2{$zh6For=?5#B7aj0)~yaesZlF0-^pS|r{f-Pi3pj(Iu&JS#X4c025Gl^ z|HGUkm;04IAHyyh3eq~2rU&h&#~V_vn93Bz*BHqaCtqDHv312U%Z0%;ZoT{x9lPAP z^cu6#j*aC=4%{4|wi3>D%j;JW+0}kN$FR-HXZc-r+ znFaFDJVhoffSqKkY-kr5=yV|D#6$NRQr7d~rgzPN?`RICumZ*0tE-yhe-RSX_2#ZR z1d3)*VKXJej&(g%EdMTt4MX6(RvF;S>mnXc5mP%=nw9T=$A_DosSj)#&bSFJl!2H} zvwd#kwya3&and%`LJ!&$%vz1%?7v?X$TvHzF`ylXY`qc1s6&ZEV)C|f%+{|JxJW}c z5l&C@C-k_8HW7V5SmjadE@+B-@t+KG$izeRz4oiV#*=GKK^^X%b&3#L!QnRPG(J&a`4f$peZUU5jOQyNIG7x0+`;m7lo zwQY5H=hGx|+KhH0Y*vCFd43B1bQ<`zwlXR9!7Tyv+fK~nEg zW0b_ttPX84jW4_{#D*jV)tBhh>@Pm^m)~vGYAqtaD*4TfM7s9-w=2J*o7c*(ysjCT zMbb>OF>a)X)k~U1f(1$;>v3!N`JOJnlJ{yhEu|eggTyB(%{eNx##w^I_o0)S756j? z_)`@{mJ#0>A?DY~9P(~^Q5>}!=a)Sgxy>5$$)LdO_WOjtEIK<(89o|OhNL%Jhp9um z2)@PU`V7mkb2AP2#Wq#{lsU}yX{ZW6vuJUNAa>7!l-ivWYVLov=a&Cna&V{exf5)p z^i|wBnkws7HQC?!giRe-A)R|he`k5OkQCxZXzXUat%!sBvSRjV3aY-*NqpZ3>QZ&x z%n&GcpIB30s40-V_+^hP_^sj z*XRwSDoTV7=_Z}p!!3Aen5|HIBrPUZ`VZ?Lhhe-XOoVvd1hWqTdCefO3dG2~{R)kT zfd_<`PDUqehNCS^)-_ER|dNDmJNLj5Y-k5&{>;QSj2s2N-=GW zQ~)}xF?o4)Dyssrr?A;=F;3@HnE5}^9TI7W$OTq4w_p|>u6@>(uJ8&vb{D7h>%R6( zfj%g5pL4K~AZN&kR$_B9WHV-FBSORfq|*OlZWS+gsA5|Y#I)`a){K&~ZsGo}M58ye z(HIoIWYK+5@@?J_qsC8FO}TUx%{lO}*hk|{l);#U((1QVi`fqxYPV_@q&JYp11`Lyg#vt|5ZE^VErX1U8X@|7KR}Sca<3zn>YE;gZ8=6-#cq& zvdGDbVHW*8)|n|S*Iz@1hS6NM;xmLZ1zPFNcbAn`vvHM1|+pF~LK$>T`ZVL2zaE*}AW~5_- zm84^LDCrg%eO=qAZ`d6|YSf6cfO)qvXNn?EKhBz@H7)@pwUmJo?gTFu>_hGgH6a`% zkQmo5vu`r+R9|j9H8w*Z@P*OQ6;sD(SM>gXN%EM zO0-0YKGleH)u~`em#YB1+=LG1p_dzhjd_XdY5s{u|FWxSiE#50nWXpvy8b0P&4_Mr z-WaY!t_gA+zKyuM z@05VWFO%2VHXYZwDiy?&dQ`k|WeFlLl81?@W=Hcd(Fn{c9j;Ir-19BA{zTCyNoIu% z^G&|YHwGuQ2e_%w@C+-|6R%RFx;t+3#xTCDR^Hc>e}?&63X<%-W2hs8)5WBugr0T@ zt>%%2RiqG*6uv2a^>)X9=3bJlGUZ(XSz3j!RVix)%o-a>;Yt2pONs#%Eu}Zi#%m$35u-;haqjm*%zrYs4BsyXFrC!|j0?cY9ZkK6(5-ob25u3!D zU#P$(=-37UxtT|Je}`~BgwUw)Z=y9^zr%^vV_FnDn}pFD;|V)ul&v+yfN~!=wW&r71V1*#R!ic#b-aL?s zSz?O6BSDwjFvzw%g*DFwfeNq!2Z7DqBDAL#NHGCiJrI<_2nqfF?#7rmFO&>|gr4IN1r zYCL9;XqDH=adctu0+|gdviZ%t@|E&1nJRQ7?+Co{a`l+)({qK}LNGEDdbt3Ti~vp? z$L=01T&|au3vvHnRHehIl{kZr1>NOj%5Z5)+-O7LPJ2=bFW&neaJ5UY%?SMM63=b| znj6YWOz2$_bf7cNy~PQ$*wOn>M{h5Bt8L$10VTPQtS=yT`h@IN;0Fo_LGK;quR|h2 z(`!s=A`HE3fCPJY8rQyj@GqtE7~1*Ea-HFKD$B>Lq~D! z-N8*ZOrtTYmDXJ*z!-J7T`JsmT{dQr{NCLB`*X@Wg?ou}T1Y zBX$x~LO7y?S%?np_J(pAPQR2;ro$XJ;y>mPNTq~xHRL14Ww>sFL4W9@mH1Yoz8cj_W)>38){9KC!YD3PeYSGVmPGu`(U)Vn?0`PWs8W z-A#{*m2fxxg}(GC|EULBZp17pK*PcNZ~@w^X$!sZb*Kzjj6g-pFbWzW;i`%}loYS~ zy7=y3ylnpH;`n7d^5?>=HQ#n0_Orwv>rr!cvD5Kr|9qeO?MS)`rXa$m%~)FxG`R)x zx&Va=;MXpyz8ZeNc15bt+n1nYOlzHZ$Y8`VMhao`zi5jpooV%H4TKT;tjmxQLkNzF}UaFF(m>*lwRPp z9mt%AV}wRWD}zgEc$gAwaKb%qVrL3uN1brnWQ1dX1O2DB`^nBe*+B^0=*U>JZR&!O zbj|55jeJiU#^K#mL}Hv$I= zFdP*yo7c%QVHTSbLxq?P%#|D)Qcg#$&|^YZe-mZL$0>Xgg=pB7eOipMh=U_!Z2}rh zghsjwFD$Nr>63vOG+b6Gq)*NLc@&)~2uu(zl-dfCj7MEf=x8JQ=BL;MULaQ%8>K=n z)**dONV*B|)uFQx`H2Eduq>CZg8bV6H+}qK1WeXX2-Bh5ct;l70XH7f=K)e)iqJVM zSZ+r~*<(jEm<|LwQd-zxDNNI$MD}mfZ0OEl-1u(HPBD7J@9JgNf)v`S3Zd-UT1>Cv z_G&9C5lewIZBO7vOiJoxB<*v@~Xf z6aCeNUcsi`$P?Rm?)ZV3#T))Uu_zY%pt7?)L>Kv|`hEr?%)O1y4-&E>@LOD z2?1fJ@#=jUp}P%+TF;Ts+DTOB=5LJd6SU>?#n&FIb{m!wvh59M{+=HDy5gvc9tCcf zkno$|itV&#zm{LyrQCHjx2eM9c0#p#!zs6Bor0Z9IB#FCOi(<`@H8q3hY${PiNsc` zh5o>DTtHOteACtear(S+6-F)u6bSS>1tt~3aCK^~5tv>BD6D&zD$z+h(1&V~3dUWf~YO$>uGdL_hDb9j~fkW!@`yj-{RgOFDz#3OY}ycBPCXh<`=YR zPKJzDq3`A)ml^N)>QTBUNRHrV_7CXs|8{?q=CfF_H(H01A%N4}s5hT|U-<)l!SOjB zIA6WrV-)7s?YQC|%!|ty)(K1$Ej85|y;_JnsRQo4bWIno4B=tA!p~&!@K0|6rGlDL zVQrG^?1Xcw`0rB}Z_bYv1eaOIv%GQGKcnCC@jDexj$KRpwG&yy$MP)v8b5jpgY=CH zz2<)>+aV5U1WEk0Z2kSBDA5&cXlx69H1|+}x8~6zDkfT-*3<6@Qf>16aq^j^I zO~6_juttab_g1CDkNA6vrtDhcpq1FmgMo*zF4DNcrFB3s>Dc!&{BQdI10Ps zXh(<%vmW7FGz5$UN-s?UE6!q5_1@n#LR{7PNImA_=fZp0a}vaRO-(B=GOyB-5|cn? zg%sRn=A0w1Ux;k>@o{Vawv2Tl!-x z2YG2b_@B+r*O>0f`^6=}n+`OGt?iK|eCWBx!gDLOAFw_X`P4vNt zmxkxl(Hx?XWk&UUYAF3j%&oQwRcHqmzmcJ(x6Zo3GiD|0^QIwrwR%6;i3ACF(W zObBGH*JWAvzCc?vY9HU@;uK5;utEnGo|_fby5||qxaOm(TtdQY{qctZ{3%J(qWy-A z4A0$44+m9M}>3yW8$r{T8;S zt-H8OU$MVtj^Cp#tXWtLLjn5F; zm(I&aYCF&Qg#YfHUFkvWeW=yeUTZTQrqH@gF;xN6_O3*~jv>#w&;td}212`KTWfsx zHpfu>W^7*%ggw)H?vA1qp%&NnSVTl_hecd(d*&i>h@!blD9a}qJ;GOz=x5y=)c(!) zf4_^t{wVG=|LXe3?V=6KW1h{E?H%;o|0mlc>Zz-3hfS@0#nvvPM^N1@_jt>Ox*%GF z&J90>TY`Y7l?7_o$GU1KqYAT7jwC8{)gfoP2-|e2OubIy2`BwiD$ooa&(n=ri<=8D zrZBSX)xHKZA;4ulquT^5pe10ED$%5IrG|^}C1&nM;q%IH4=md9B7=B~freRXnA%l7 zG~)9{iAx_huW2%-ZSn8Y_kJTp)(`?D6lj~oB&D%{Q&M#p* z*>o5EkU7oxOiIYn{4O^yqIiU@GX-k!nw*5umXXa30FV}O)SdTk}B6cAz?v- z>d{_fGjmpf_iqEDsibhR@BnNlwjhR%*tnrg>@mf(A^TWAu4buP; zQrd_s%6T=Bm-D6RLWp(2RvlVgORTWYp_tETfS=TSA)Thlv+J=wd1fqpcU59aaN$ED z%8VoMHErrAA#<4i*XLiq{^0pErl%(2$@SZ1ospckWosh;JbYmBkNZfQ$5&mli?E)9 z=OUI?w*rGOCOZYi*dPpb;kqw^ zO}&H-Mm49;sPU7#Q&!EOJTIB5c6^O=kRwLu6=wl|?d#R)atq^jEvD@D>6INubnyam zb%Zf(h1V}An#Y4gLcC1$qy8m9A1k^z{octQj6|Ran;g zi@j*KGTuO#(L^cGO^|y2@mzuM#diuF_t}J4U%RQg1Wp)_*?j2+(G$p~rb=Rc-CiNB zb9)hR!?ci?`8`NJnya2KRp3P?5v|w;7`KXEZnq)1L#QfOK{EEM6`}{rlo*e338Ch+ z=qao5{h4py5SGIr>H~U|-zOff)GDUVJ`Op0-Gu;ats4#qWY)hBz~Iyo%IHb-f^A#K zJFPL?8NJ5GG{x+-#GD4+Px%i%A(ZN4cpqGF8;(3{iZV3??DDL7+UM?|7HBkj2KRq> zPu0CstM<^4W-(6&z05p^{PxL_H?rTB4fCQae|)koXYj}6$h$x9pojbe5Iks~%4s!^ zZPiOY{qkRbAsh#!VV&QtC;#~5&i-y$ucJdhrckwvw9EvAhji7fR5ehK&V6;9drpMA zd>(RGC&d8XTgXX9vCBba<-+lz{mX;_^hs5X_h|K^V;!Ts5ag3cVFZysf)2Exma;Sv zMqop2aFMhs%22)9rE0~lsuUA8THce`*Sy-GNiV2M7Gl;|)M>)1#K-)|nGTp`fQ~ik*6Hw_ZRHLuZ^)5M31!N4g z$WEyfVHQhgbzwzs8o_Ckc|pZzXh+V4HG94d@Hy#aJmjZftek-&5GOr&kOu-u6F{Kj zk_ZzjQmbb1P(h!d1hn@eD~$O1t3fu&eO12oIx5Zvva2foCRA1Zn4~;8TO`$kQ7Svs?~8X;TNJ82KvHw`pihHc+zqo^rFsRdJgDyrwVG79N*vO87 zJQ9xjPtb2bM?-B$-?llm-v`pO3SGF47T!kvt&sY0g8DxJ^gS|$LKh%W8zDT;$k=^!!h*7uqN1*& z0_U{q&iCepHZPG0?LXTSp+t?^q)``vwS7%@rw0>;Qic4D|$iO z;QxdCENHLOP=pm7W#c~TQ;(n^koEmYA8J7Xo>PJHGp1bh^GyaowjKPs2W%}x`4}vI z$_Afy4XPwz&_1qZdKG9Lk@B7YJ&!5;f;xSN%9s8_|Xg+k=GsN zE9(Seg3hc5zRFq}$s1&lj2FE(omRIcied?e8K7CDw_LT?u`m7GEpoL2h z;m%&4aax*k|WH({!JLfMr76U-|<*J&6*-v3IFI7)GgIXuPePcxOJLKMg&M zk!oUweh?bp!j;v5S z(;9CfC>$fI>g^F9!K1yM?w--mg)?)tSH{6mDayr$bZ(Hm*#kLSWD!zLm>eZq5eh)^?0zb~Im&=eZgWWt8;PKW zB)XX3<+@^`{wXTT=?ug680C#`efw$bHxTle?QuP<<_&4yT<&}`2`!L^)YYQBAkVLQ zG&U$?Wuo)?j+XGYpF= z$46W$kE`7*&}fy{g~EGoGYGDaUz-7(y$_maRsIbDUUt>!b%$NN{XXgvd8sKh9N8RNfn6$x()4X7*LjA9RSAXr{S3&j*qHDM$i1C7I^NHpXm1ON z<;|*6rUe1K10h|stu5#R1Laj4ZH4tgOqkp6mIJr9W5u?I({~SA(+>VC!p^;WMgi!H z3~vsICyt2`w9{pQc~afK;SXPHH%gKp;vIo6ySZ*J8@RJ)!d%?>Lj)E9`XQs8XLVfA zCI=R0esGBNXeR~q){`*~DSZ0E8Kf$k&e@%GgEkpDfx0YJM_QqPp%l`6b(r)vpXYV) zcynQYQ=ti)rsqe=G0C?3Yr&T(6V<6cw+e?^M8c{lX;sh)&6ec$AZ4;*bn`{=w{dQSZI`RvX2PN6o4>t7IHKu6j(4{tYzy+;L3 zpft|TD?2nk`~kQKFhQ@W!yZbz2nh6Q%rvlwR|S1QMGEFlE~-EUm4))P^IR%kdFdgV z0i?}zEZP4n36FGwkH-ZgG@*BdeSL%+WghxApkdV0id5ccoOy(rUU|gJG@oy2^nvxt zw!wR+Q1}y|M}?z-fPC~P!rSWYWdzw$81aJSTsq}+8>FcqZClu$aieF&DUKj#pYCpB zR%g&VgsVmM*Gm7hDM$y^!gXU>_}@EeLS~H_A}+5+jme2*Daas~XWz7HE(_eSE<`uSe9tKz)zE>sG=sNWl&T64Jua+O{MlI?J1?G;EFf*r>Ku30kz537ud?r)Ty2!e`zGMEPd)8Uw6sRjqB%P&LQpzmrC7LCH2v^3aE+ z1N?Y2CDL6qd>Xy@wt4QgFD2%8hlvCA>kdOVX(_+c9hOpLQ_~zQ$P>Zt0SwCxNjlcR>c(e{Nt!A-Nskp17=)XyVd5N*JqF3@lRO3CwzV1l3&j~B3&~qbu!>DcBh=! z7GuSNe{qOs=<{EB-E(!Bk7|{XW)KIZMN}a~9YVBby0T;-$FLu7S22#E_BdQ}GiW(l z^?qq=pV^^%OCMsi01+w@-klrr_wg9G0iF(?&s3(ds-a%OHe(E-;-|a5@@-rba@C9Wv#Ze&7O=APvcKg} zlmX=-yPg7LWi3DLLJS1JXlyU5Frg|zC!?HIQ^O#U9%Ksscn}7;O`xttMuhTQjcC{M zEKQ)a!beumRYB7rWw{Ld-z%R5dD)Kcz^DXnZUo5=l-2N zV2w#6$?8U+g^pve$H8Zq6fA;_Ti++cw3911Q7x~F!BVdTabiz$Lq|v-#VexO zC*yTUZF5gz_K%9LdsiqY13NAiXAVdbn*3J3;hnlD!f0tO$_MhhNX-F?RHzbuSPb0h z_UrTS@5Ut1eO5Jo(bi30Z}O7AlJs{&F^q+)74l4vCL*1as7ud*Bv+a|;~($>c1tkz zzDWhMGW9h2AjP@xjXh}NRfayiv{Kf99I0XR3RKef$ZHh#;>Ouo`<%IUziZ?UKAfw^ zHSt28jlQ>StSMkER)w7kx-x2F6_E4r^ zgqoaZjG+oEbYhyQ?Jp!=Chd)JDxa}ynA~XuhP<}8OT(4;gEgFWosBmZKin^l=k6*N zU31L7BD%);*&0N2O1NBcjUTY70pk2W1(TVY-mPIKCQV5wNeIVL&TeMc=$0&rdkrI+ z?_T9xWL$8K1LB{>^M;v_#eJ_l)zaR|hNSK_qqTEd_ZmE;6|{zJw;Zm*5C7)gdZxD5 z%OXeBFCIqJEjU)^JaEz-GP-KnCCpP-+1tn6rR;D60)T1mK^UhrLqBZ%lmuO2rbvSN zsA*@A!1m~{pjc|_)e@XjNmv(_u~OE6#rJ2uADy4%153e~mT#+1$BIjEo0I#xVx6Oe zjv7{pm8M~qk{aW^)+Plfa8BxjZ}@DVaZL5Hsxg*RHlBmKw3ZY_;I5Spzvq{vr(WknQQMnaVu=2 zjA+{t0E{kJfF_3+tnIkypyfq9g5^y0FS z^0bud-ziDI3Vr~1K1eYTX|+B1C4>*gV88s&R!BY_dnbl5vPYn%i@qqU$bI+VEBX{s zZWBAs70Ymel?4tR_u7 z%aw^}saA9-fDB3%qG$zzk??e+#|pS3=Jo@R^fLGDJGCRulTfP*268O3q8#gFRVm6` z-?cV43w*W8zZisZS|d)+tmug1DMFS3qKFS^Qbg{=_#TM;cO1qC&J1r?!t6+auSj}> z1{F4S>Ke3*$tcI_H%;#UV`{}5k5v|7_2S!QFiDjz1510lG$3hte*(EwIl_L~gmMal z#T6aIOuv72T`Y*p;@x7buTrP1baz;WN6{K#pC8=do-QAmKRqXKax#EK_qMg{ZX?!5 zHkG!vv5xu7qzRi*5#p^NrQ4;C(p=Ouc7zEs;*r6Au09Jj8Tv%9HT7F)h2up z)}DR=Qe;-KWJ#zF(x0Kr1QzE(`)I(M81shSX7+DME75Pa9OYMQ-}ziE*NXlAzbpHgy0^w0RI|+CaHvIv1<`d z0jm+|w~29&4r7x(-oMxLf*@5vaE5sf?eQsqkWc|J$`n;?+AG9RBb0PWd7*bkMG^4o zSg@bry6#W(nbQgM3jH$L?A4HV(CyzH1c>?uZWo_0;{>({i+{mj_yh$ED*&G-*wBkQ zW%Obtk}3fS!9p3O%nn}8Gvk-oAU^LsES8z7M%$6pN(4kwrJ?Se<>u$Zji{A+Yy?6? z&D%Ji`iAFV3V%esdU7a0TJBKlbBhYgxrQ$YV6hTK&>)0bEw=DtzZ-1hwoFK_>YF=IR7HJ(Fg zcER}FK6H2PCuhK+$pAFC1mm<@y9kAE4OtI7(S59|8jaX3UP$-380N;$doMM84%2lRz!Fl*I<+>|UIag^Uv z`zLi#N@ohb8&qxZwmrlxlcF8qc=waHiaT}7AgbM7%^%xpi8qKCsWy1#9ZTL33o#q5 zU}Q*Fs%%#DRU8HgeZ#|_PZHAxtx-JHE6UoiVa_{6-5OFZ^MS4=Dlw=o(b^SXBEWdp zZT&gn%uP!9%%~5HF7;Ec>~66ELP9GpK({}(T;Uyn&!VZDM*MS+Fn0S(rr&?oK157-PJ!K|v9Dca9RfH1n9P^{K^X-n^JhXFmuBW}@?T zLoR+KG2lPj>+gJAeQm|KCP1F!P^1z$!lc%)84;sQrOu5}X}tKNq1f4N#3)JgkoT`VH_(vgqi*r!e(qGiJR0AQ|$9>Cg1Vk^4 z-g(Zqe#d{{3D>j<$DJS~3r*#=O^Sefx6GvG5w)wFhvF(HC(h!kxC-Hp`VrBmWxOA# zG_ymS4N@LORK1-PUi$N%m9fmYg)7q%cG?EFmJ8JEPjgtJTrYZ2sKod=YQfV>p@mzl z+(-W+aG~)SFdmsA5$V^d7|t>5;gS1yU%4;?YM z0RLZOGXW-+H~YeUBb7{c|5Zi>-OUx{{pk8J9?G*eQ!i}vN>U$nA!4|wU_l*DtmivH z9bE~XUlKZd6mW9l9G@^ZhHE%+kCh|>O71!KlYYdQ*?_qOFjIO=-QfgUMIv=p_*E=8 zzq5|rcFg&U&^Zjr_=H$;f0AUGQ&Gb};?SLQ7FW0&3S%%cXN5zyhg9Bm7VJ)ziRRHn zFpttJ31X)YbClv&Wo`7ROy&^&`y{b13F1w*@ukhF4dP)N$mav2LgZ+=nOzKWMv?ob zK-SA+F)PE zfDPQ_=t?LD$>f=x^9T$kn5nq3mJMcxfJD4Q?ycjQ zb7hpDFSa-h-CE~$%@J{BE^l=@Z5>5>&+!bkWb6=5C8Sqi^8w@(cpxkqq-JX=B#v+S zA*fQdG=sK5|A69yBu^r!{&8<+g$xO}dGhF+aYRccqQ+vPkEvh`CDVRrcO)`*zQx6prcC0cQG@4koX{KVY`l!t)du^6p4yhsUqRwDU*y` zVJerKXi=-kW2=FrxYh|G_J>X6Zyy%`VD*@(NM)t+K$gGAk%S7JLGFv@cik5AZeqBEbVg>=-8;m&1{<-wv1A+WQFvy4XWj$7Y+E zsvD#uiEMq+ceKaMsCLRU$WB6XXKm#X$b8jQU^qyXr!j^wRcFl0mLTavbxa#FB(*Z0bjL}!xd1+-g(Yef~>s{+ZzmgEc(lJAGSAZeoKAz^x?WxCrU zgINBWs=Fd*+x!c9QcT#Z;m*06zCJCYba8_ITJhf=@KU^|1R23V9Ix38wk_L4-o>MN z1{ad%KvMJVd1r3j$pB%dCdXgI*2m&QFu5RDU5>m_Zf4P9u|1vdm{)A$vDva-_$)FT zjU>w~86jrT@GZIm`SYVw^fnkAh18z|?|8)trTWnZh9Ng!p2r;`2r5rDfCRHQeJn2AIG(NKTJrkx+ z=~TQFQI>?H3ZQfpljDJkl0%Wk~1!Nz1PPx-ajxF+l9g*KiBK zJ1bF)%V6+X5o>y^>Fa`TDhhnI`nOp?9`6xn7J2wgt-H3pEhkeyR9ngtdzcBw}7H4;&R#_z35u z>I{uFr;mag5YDC9oBe}C-`QTDwXK1BBNtu25fE^lIYDwlm>Fj}*v(mUyDhGqwz`^9 zZ_acbT#OJNo#MUV5nEBW^i>i&TO_s%1H6zUx!{{G+!Be`|G^omWTdG`ozlQ%6a&njKv@VmyKb_ zTi0%l)%RYRDZn`SBk}&o$x*UXvB=RRV)Kqrtl78CD;d(_cVXt^ao|pJ@{(O^jIWdP z6R;%;V@^vP)h;wg#AoIsHdG)+xYfXzDzmPRYG3mXn8N|&HDm>XI(w9CG0o08#Cl8t2h^3CJl=7=3 zA(=!N&&#Aze2%gzG1R1L=F8ix>6+F`6YUMm_1Zv2nhT6uHxG3Ha^arB2b$m)T3e>Nxe4Y_~QiLY;C81qI*=MTc)D3_)+NhneN)!=Fdc@5U>kr zW?4kc=8mnmMGV0?27qj}kY?s?A7tebzVpA{cj7MTvx$y%_&Q5u6t2sJxznS3?(78V zMC!-EkB&3v*q`VD3E;b|I?A~$XOa)RDXxEEB5e{m2|=f_eJq>k&8YdnDbbS)iQ+N* zEssz9F(x<3%t#W^W=(NRw??e;lpKbn*YH zNbxSGQ*;aI!tQc=_NwC7S;cM629t9o(#2$=t3Yn0C{udpSDxT^g$k=s$tInWdFJP) zX}WIk)i`pFZdK8Fzmvxic{`A62VEG}+Xz{)yj5D$HPq%W04v@~Ah<#8*l?ut_4^=+khD+6(c_o_9h(7`RcD|YXMnx_F zw^RrVCJ}gdEa9z+1ZV;6LZ*v9bv3x9dFnOt)I6BN;jxIUM0oRP93@DuT>9$mK^y|y z#O$ZuHqp*-$`F`~dv8;jBEoVr+h;w!;PmNs6Q$Tp!h@R@NE@?N^1`b{0F_+PK~^vp z-BB?O2&R3N(_|rY_dV(u$ozze{MAhH0G%rh%n)|{ILLJYil)0cV_W9!pTj&WHdAfj zw77s{3RrvQ{HtxR&qjd3SGJ39Zr*y%%sH!Q#0}7sf&%Nvv9#La#{p@lVsHtABQaLUxkV_;uH)%_k1*Za$Dc)7`N) zeTL37X2nz*7{#KLhs>hp{-W;%~!-yFy7zuoF_`{21NZ1zsF z;^438$Tf%n)ILUwOU+KFn+}gM7&8c%vqYkHC zrNT(o5alY^(?!Xv{{7543UgJy;VKSqo?};2DpcDo`~M!i=Je^BV^K|h^TV&)u7b*g zI15%bt8&f;^Eq+bS}5!l_u13n;=fgV8uE(Ep)JK&GIRa2^l9bWU!dug$=}}&OhMwM zzc-E}{-d~tt-nf<8{Sr!)_TTSJ{Xd4^l`ss$i}EFUxcgucuQ zIiuG$V8B{Y?+C)VHp!gt{HLRkyl)a@q>H`}$P4bek>3Yu2e3|~+C}YuvUr&re-ZWR z{tgZ`&zJn?J83)15bo}?xl_pV08s8O!#tgWCrkW%)67ml(&VB$!G2s)JC$F?5=hl- zkfm}z-Ggum4usC^VFO5ENJ)2oJ#Il--(+pmOMFY%%0X!t1DDttTG3`PGXM)Nv`}TJ zMd#h1hHls7r7(K<2YpvQ2&=s}60d6s>x_CO@~Wk_Cp9J|q@Ji@b;$;%b-`&pH^WML z{Lg+kp6}cP>zlPr{JEQ*-m3EWCLpAk>MGgi&vRlXlB;BEesxLWq4#} z){aMpcA=eHZHq@`=VG~K$L^w8*|9Ah+uCmV{qg%7|KZ`|F`wc6dA*-6LuR#mo|Ow1 zJRDQVTlPyg1&Tk=O$-ZSN#8ePglom*P%xy}R3m(YlX$#(n&2`{ct|b|Tq7KG(+$?N?Ln z!&`#4)lXh$x+koEGSxp=$#m(zs1vqNZf3eHZjElQ{0)(~L!8oSC4{dttE$;zAsvD@ zA`=?tnh!X&bGSrQWgTU1JIl$Jx}C6&)cFF8?bEQ@DN@p}rf-*sE1YJ^hOIPaLohUq1`{>sty+HsO6cPwkSYu+?=-lNQ9Syhi39!KbYT7cq%DJE*119d9j znbK19mT#jhn>k0*JlW{kr&DRHHIGG*ed;Yk6+N8nUO`Wl+JNlUV_DU{?{hcRIG1R% zs+{;kA}dbt?rdqme6qNPzM7m(VwaeaZl~fYw<_n0qkN6acJmt&JL#Z!r_<#FK#yy( z=e?=`yPX3yM?N-3+Z@s&+5Zm)OTwv(!M29N4hCL8SA zd_alTguIn?RQs5|^PusSxEPuxlUr^XGxx@5XT)R&u2uyW=Ijn%9x*iJRZmuI_& z@tYNv>}^6`ww|LDX4gcdD{|UHM)Of@%~L~1wXnSurEaH#VLN&E5%=9bSgCkb&0Uj^ z4J>WysB(c13{u=iu?@W~$u$7kyG;7DT3GA&<`;*AZ^b$Enp8I%PQF>RQ?uPT&?Wq+ znF#z8_1L%i@ZGm<57ubsp z;%a}L*`<7>P1_awS~rsjF$sc3q5lkiTBr<~lTafiqItDHCL`zxK>$yy2{Uot`U z)0o=M!57Eet3B_|illhfE}5zRUr7hPTv#NZJ)58>(c1Ct1qj$OrCfh|LPX7+=fhCq z7OLMPwzUjsb{PWjr2RETW;_jhrF6pegPZm0B}6Id6=!RM4`PW=<7CmKHOn;8phwbW30bh!2`0jj9gNb(Rv$frJB^csY z3;XR?QxE7bunoJYOA(YQ_(TX0i6~*smb~n?@}rCJT>Y^D=li%Sq0x<6dLGPUF)P0N zZJdTW#=w&%Xo2}AeIq#rrnObAxjRnTD%tbPX^M!o>j67Gh&XUNkXs;8J9H@s zVQ-RXv{TQ6Usv}we}u_}vRkZHINsr^+p1K_0J|6oan)f3h-A_1KE3(`LDiP$Q|)XT zCcromgyf@iotXgVTTvF^RCEWO<99o-P)VLQAo1 zkl7;mk&C45L-ZZAdILzku!j%eZ?jG)FT%yKP^?Z&wl?A&N86|&W~I{S^K5Pzks60Z z&$=Kd3W`-0mG(~IuBRxY9D$x5FxeDK-DohG^2h#fHaBVsbKW#rNgZn({=Xz1J0Xlh zT;Js@%PE;{)azTmp{_<2ewaoU2IZmD0tLZat6K#2&DIxvIO1)VLzGAO8eHCCx(p`2 z{;CuiDi3pfvPcfoc$z|S+UcV*V-GNWu(u9!ODZ7ntH3GMq#48pFXMM zJw;Xvgv+Q%u_I%ump1%mHZyl<_-h~)+7pjwt%mWI8hal*Y6k7KA~A&A4@N1`(1X%2 zXI#8raMog`FgxN44ZqbuMg5+(w7{cgXboe4)~Z$Sn?Fuo2#Zg*VzHJb=xK6CNUG~v^XER50sQV* z|FPsVYTEQCGRt(9W|`0YtVW1m>;b(t^d`wuA4`2D`b*8oSvXFj1ie=g)6JXfzrs{o z++!8(${(=M?Op1306CC>PrieTjWVd6ov-M;3lcMqP@k|1Obx4*B_|SccEFTODwsM& zV|8|g(sRoVKsnBHRGE?*-L!GdopOZS1(Pe~;^y7(^POQcT@<Y z>qvYEx9()Yn@z9>jo=peuRCxZEhQO*2fC@`*QazGR8(;3BcBi|!fa~nDf zNS;vUjOEz$kr;WaeRm{yeUe(Gno9DkLZ(cCQ%Wi~tMoGFzKmz#{3wqs<|GcflB|V~ zFfVIM-@JPJmdJmP=xAP2?4Z-U8G10-SIMZ&CA}*m~2-R`Y^D)4s7aRu<#43A1W6QKpp7 z>kXtxc}y2VxvGZ{WdHQ7c|~yCrQC(Z9aQHHvZu(+7|GF}=oR}T zu5)Rt@-5C2vdq@*xcijgj*w0sX8xi*WJ*<7aLA2SifePuW3;mvO^2gfqSx7Pl=Zm8 zNI9RT=;#!q_3Fm1|P^aP%^S zT%t?JM+&zi8tI##x_m+;6Sv=3{T!%Wn=s&|FpwNyS(g-2#S4k2#a5W1<}5n7v5Mr? z=bb;0*7)LqW#Q${6`N!5o6Z|hq=mLL*krC2g5$I|)Nf9&rTvakXznMAky*d-*3`mz zf#jC^CeBmCvt@()=W=vAOl+0=FUx_$6Gt<%58s;b{q7n4>y6(Rwiiq!5SJmAomPrp zc7Z;d0FSBE}*EPw}3aET~z)%-J+r zNg;80H1~~{RXawRlFuzxQtqpJBuU)L?1!sR=(3cG!lApwg>R2!YAqiDbGc75p2%rf zES2PDwPK=4Pi;9?&6>d44!Qeai1yQ~EA;5RDW#R(1f4cf%|nY$ zPf+?`D_I4kGeH|oxbKTkeL1!0e=)UE!e`|j@yt_dU(Bf3u<_0ba8v%zT9e;&1Em=J zY+m#vUrAH0hZF`VKaVO8G!NxcR~y!m^weQ=Cd!moRvf3>%hY50fmWL zjs(7vhj6i<6xhx8zFSW(LOO7SRl}FZUPI+_@^+JTi6-BYy+q!OVuBR0jR#V-)J3xVX%Y2w`n%($HNk8W>cZCV?sd{iNQKI(IWat5CL zKzT5?K)_n$*@-)fc+lS8T-Sx0_jIO)QbYB>)6iME*CX=?bmF?#X~V#$g@-MMZF-Pr zK6ab~;$5Suno0M|>vMp(r$s#Cb0s%T&S2+$u$UN#5>z6=#ee(8h(mp|*Y_q^^&ymF zl_5K7si(0OFH!d$^6rAk2nB29iA1PTPFsJ|KmF zh-us@Y&855jw96IgrzU2mz1AtqN!7jpT!8}1dgCFSnxhbk{CZiOZhfuUK8w>&`wT> zX=%hY7YuC;Tuqtrp@89711h(4O&~E^ z*wS2G37RlLCopa$lDZ(2 ze8L1mdJ%_3P7MHD*KQJPs{TldQ6NjXSD{Dp%w?9Xa9AO%ewb;72yfmy*37K2AfNS_ zt4DF9yRjB2_5^?QFnI&F$szBU!E~voOo!02JE6DPHbo|M3+|6Y6|^VVRg zV5471Ew6Tudb;KMUzIVP0apbOd4xheO?jR@>!z|y)aA7uiRnNq?s`eODkqP5V!?TS zMWZEs2FjG-aH}n)TS;Ap!Ai{fIZS->=gL|;>M{fMnEXxUAf@(hxAZ|<>AWKPb&CGR zIX_-kk}N=%&=bC{um!d`-5$I>DI~nePZ69Dp>0EzY1)C1>eNF%ZjF|(e3#tH%Z=Vu z@rEX#@JslcZ=UqMkIGx z`exuJrL-VoX&kjk8N0qCqY>DBa4h}R(=9*b$D2B7aDS0$X*YqjWmro3eCD5U!^UN2 znxDfJn>1e1n3+-2+y~K%+TR9Edw{vzQ|}Uqr9$ZJQ$ZW{pMEu5znvuX`k8d4FY2FA zzn$w_7DM4Mpnv}#U%AUqzAI7=M9PWXH^G^53_5{8$DTLM3F+Rtm3AHhPMD~T_GCXL zHBy-+sICf%v>DW8j`{wauKelE;-=^Fu4TBxh{;3h2?lG@OAun{K5mHjDtWdIzs>Ft zGJEa>Wz_T{^wBbn7xxqIWV?w>eYC+RyX!f&FseO|EvEHg1%H{0{Hh`%g<8U+O=yEo z`&sVc(`y9KYW_%V)8;N5CC%&F^6JOC+!rF*)Jcreg>)4ZJ&~rqAUMQ@UeDN}pzq zwhrOttcx4TKX9oOAr|bOId)ClDFe4WLztajv=upl{4CZhi>8{j?I>(DC2zmC(W<6r zESR*qee|-MPuYw0kJ58;8>}cAoE0~VA~t~nn9xTe2o^$f-pW2X6UL5^Cr{J*%&d;~6nr`5(f!Z9dXVaK z$0{n3nQ##vyeB@StS>W>ci`5eL+6~Ig3<>W`BxTxtEBWG(3h^kf?TiQsxi(}4>>O>XVZ%7+(Eo=^h;a**%JW9>zr^b#l-nnAl3Tf4^*v z^zED5r~!!VOIeSOc{y~tcC0Y2|EUWHVZP)V!P=1t+IE;0x^ebv6Xg}goHDV_!uIcR z+sEB+Y<=f(|3}%OcFU#W^|b~r@9~(Z+>U~_;_a=(`Vi3j~WjmMCXMa zl2Dk?jXIw~z9>taXSK&~3;JdFJ$S8=!)qSZOc1K6u(0fnMD;XI%T0J*+#eJNEs1O} zPj;)<*fjc;o;TJ_L?3Msrk(!M%=E6l&6g@aPSp*@(fW8@INO1<8kfk@H-7y9Fw_*n z*Rr#K7|L#ztck^IC)~2(uDVcKclt{0K7PlPYoxA*(o<}D7JrkSeZmA!AmvSW0 zK5R>ma9h#ooen5}|jCp%7=QEj)}XjUudv1QRV zoJp;CHC$$4gk8D)(U8Rv%_H*O$opyYP4FO{)2VP{O3g|&v|y5P#X97ce$aVNf|6ik ztJ@DSU6qEx=J`_wHNidxiME?{%K}G=bKm7tZFBnv?S(OdYwRV2-YX896}AM<7V8u> zOT9z8)S=L)j+h&+6IHM&&*gZU(0V4|DHf#B4DJ%*#>*XwZsB5jA_?Qa)~3IOiFT>y zA^+%Qx_}CLq;*J@LwN4tXm-eY+iUE?p%FDbZIY4BS|n^+Xt(ZI-BngmjBQ0_p)jDr zS^81=fR(=1Etd7?otrU?pOWh;SgAw44TktHk4`$Q8}41`7L*)7v@6gBR4v(zYglVeD#hPwX6Tf(f*Cu?@L)#WlMXq&RR4a3xxB2OXrH1`4bg)MDWIpc z=X>TzbxJYUyUhqg3C*fEF+?9aqvk_g@V@OX4%QwCy|9zpNLW zNWZ<)axWFlZpWR`X^*-mjENjl_YzM|t*0*dfH0H{;w_2DMr1^1>^?Q3=jOY$OT)YB6>sgqh@Ofz z@@zEv>jh)QTzMO+Sk&r-043UX)wz zXDEFsM-qNLN8l8z;7s#??Ua&u^4G;Siyw{K&lDi))5O^M!RPPOa`jHGxdw`?T|}CW zS@?B`Ob(wqVV8lSoLdX6mfTEXrfc`pjmS6Y19`N`QI_+uYH8>+!U)v+$zEwmQDY<4 z+XV^UIx-~rS9ZG?JLEZqP|}2AYQh-E*kJ;g!z0cXuF9$mAr?=IND3|BGHSwEZt7Lx zFl-%(r#MdK)GF)2!0z}Mj^4eY=h1I;wrXdT#)8*lUDU?fc3@OuO^K8h{(kp-}jj^{b$RDNd6Nf*Udp`GEt~wM=6E zm2vXoLfMbz!zX*SiG@xd@thVT!HjELyG}D;yG#Lj3Eb8tDO8qRCF-EV>5DK)bdi$a zrFr~F>~@kFsdYTGh@T(!E$ZB=sN$d3tZ8~Envka6v5W23vGD`R!x`$Pkn1oY#;=Gn zsUA>7Bby5!{^aIyPWIT2TBIX*Ce4CM`B}%_wFBxREk18X>A$HTZ#6`2*=L_ck;uf} z`Kx9K3t<@^VAKVF~v;A?F=XP!REDciK{-B?> z(iGzjUt~TG8FvTQwk-KUUmiZppuD~(-g*{WI`>%9<`sGoX!Y9PrBfNX{d&K3dJ8Jp zUE$A^jj!&~gA9Z;!J9@-J&#eZ&d3L52Uk=5aPHpHnAI{=e^3*tTkSu5LTbL(AqeuE zMcgf&vOP(Q!a(m}N_6m)r5jyY9-@K?k%oJ;`2rPd6r&v1;P%c*0I$7(DfCastU4S@ zQ&K*>xT_+OU*az)ls%GRWzt&4?`*<1L}%@vgpHdlYnH3_f9V!bT*i*faGxmC7wc3u zC**{e`E=_T;{dY{wv7~ULd}cCW^lO_Tw+SzW`Z|1m02LjYpC}?IT-UC;R!&^zOo&C zV36F;$tSK5pR@$=nJ)z;MnW1=uxB*xSNAj;Ye_2Xq%C25>I736@H#VLt8(7eU|}If zOoN?7TDJt8e}NglSpgN|;-vy8QkAw;K@96y%lqK+aUf)wmay;{DOa_6ODMEZ6X<)G z6#NYAX$r~$vUv?46rixTB`!6kEyI8~96>)0#=*gU#0?6I(~Mum$ALcf8ycDfQBsEx zGr`}8cOLY;6-P{Egok0I#y+P=ZBU^Snw%+GDf9FeDFu!ot3Vr29J+*qBvbO{=8S-R zM=zdqt^;n?+B3GE+p>#x!+57%bpfJVMTxADkJ) zljD0m2qGL44iRn?4b@uWZ7zDo6{)i@+qs9G_gTjs19u7W_LFAKXzr zGD{q;oNMi{A<9fR-ka4h5ba=y3x29dwj9IyvTyF<`35*r0DfOw_4^XMqbWxy#s8A+ z&Bv08(%0Dg5W)d_w4aCwmTf}_-c+t1P7&ms2+^g=SEDx=Yt{1l&64YSzj)#6LaFUV7S?K%UQ5U zzWAT}bJoEwjvn89W9uX}MsLBkrAxUYAy}X#qy-SuO^#j*U4wBB&x$tb-Jr!jmO4UD z!BQ=}IVL$-UzA(dx&dFa`yQAiZ*euP`-LfuGR*CilqVYq;YQ$Jz~73xY7QJ+pnwW` z@je)^2nQqyyrKj_vJvoq?L^zZ2^R*yMHn$f&Wci|6w0MiDY#1|`01{L12cmElpr$W z8`F|;KEdC%lQIm1WGNVh1ELgYy4`G^!C~urd>W(l-#pX-sR_6z+^{!DrMBycp-hRO zy$WKp!|IlCm4gA0*Mo_b+s<%z%S=T7!%(80@G_;i)39eb)|Ox-7Agr{Sqruy#8w1q zu%u_;t?B0C4dC zmgDM*rNm+ZwE-cn!Yabw5fC{i4g;2raaq}RQO2mPgq<7){#Ive2pGghzz8FNV3HN_ z?w9_Ry9n%RXQlOah^y}ZP;(CqXUYZX$0!rvhP{Lpha31GzY*!zEfBLLyJ3>e=$QtG7i4>7xiDWG^IQ>^reJna#u*l#Q%Hd^KoCgNrb5h(p;sT@S>@k{jt z`}&{)&_P{CSTRWuPO5kc{7keNztt=WnC+bb&7XltR5dBms+Rhd$Ne3r|N=W zytU=q;|RD%vF8D-$PA{Z%pI5{pz<|iEq3Bw&qzI?#7Hm&I~JIb4t|pVX1^kQP4z6o zBG_TE34D*vdxz$4JnJRATlw{&Q!w-BKiTy`1K(g0&q-kCeBW7}vP8WttLPyC93iX~ z1QXMkd_CSov}MDKTooM%!)ZAZBGQu*4B(;c(_Ix8q?%P4lX*4@CZxj91+>tqAz|hCg}qX) z?T49i;I6l9FMz4Q`PdGz880^B*H#@#$AQaW3a*~5JwtF(5>9w7*xdrHGY~T5(C$}J zPotrAo^lJTx*Z2~7y|bxD4)*~isg{>TX2P*2)_&oRuPsW#LfL+ijr`=o+v>GpFSA2 z=`)W{1nnujc)eOAF>>wtPDR3(SbYwWlMW7W&x3!#-4ozF$5PRV$`!M%+lqI9cOD4v zv{tol?PMiXGRG*1&vKTe)vsW}-XT(=^A>j?p`X>cjtc1tdRlcE;DS zO%U`ET&sw4@qz4N0@EyvFn+JQw@Z9I8ont0%a>4Ib@-YCZ5%>u{Id7MJz>>|WKf9#*?MSo#D7*|p!2SV>xwXqB!csn;{io?Q!13txcb=3e$}&r62~3NC+szu zYm(0d9M>#TW7nTCXERfh!=!j0HMkjfDME{%g$bMtAv<0Z`k%w=Rp2Rnjk$@ySYxU3 zV@Pl)K}G19ArU4$ueolkUcO2#5{j{nmbn|g5M=G;OoaHNDG zyy8++H}6vR+#k=m76RE>BVv53fyY%~mCB$OG!3Ef6au>KG zXKk^bxS3*SmRt>Jb9iQZGH4|AuQt036F1!2?!bn*FZxbN;B~Qg#~SN6pCCJzBCr?_ zhe)AR`GfC=C&A&n;+yVv(4K)$J zq4zD}TUTLk8 zHo!A+w@WgX^f};KRn5HgF3Gz(Zah6Ct!rpoA4Ai~jXW)U=vL8b2PbV09#4smwR7cl znGC%i#M*-h>ClR&(Tf`4C?z2*>`L4or}(aqa=v|e4Lu@Kuz!4fqN;U)KaB5>+1bS1 za>y?paF4fl4r{b9wYmNpkA%RFnugA+&fbeLfy#X!t)3SdZXRx%xL(IE09-o<&fciI z5ZvtZvMz=~caCh191($pJ+2`WY!Mn`rQy#{H)D@GHrsO(X3qMwp_9-adfxmpVSNly z>k?Gj*{m9(q|=3|I_*R573g8g+bP-8FN1ch`G;^1C(5BcFOo_l|E%+<*uxW*CXo9; zXt(u(9k)skO$=A*P%TXKDb}__^of5p5TiztUB#)jytld5%Ed(IVhlXxjOhAF^dfAk zrn^9bR?@qS0}y>GJiF34(1?fNTlY3qO?ReQyv++WPoddb9ln}g&@8H)+#!vwIWH)+ z4rXuHsw*uC1h?vI?Z);hy6j(=2(Q*8?BuWN_3-z^#@y-r#_77b0Y`?>oi6Bbb}c6# zbEjS{IX`eDB;6u%`Ur8^m&)bK2194+BB$|HftZrN(=S!Gsmdc&u%{BPy^+mh7hwR| z&1}k|aZ7O7Xq@F5P3=}!{9V4~be>>5fZX_s#>eoSW*p>}I@U(I%N@;<_#n78iJd2I zyoOU^ex1i?m7|mw#io9_PiZs0+!aI4LErxJ6h9s^fpzrp4`FT;x_szvwMVe=g;h^bh& zF-u6hkGVZ%-%^r-6=G^kZp7f0=M>jroZTwcT=%c=>cvL%za2(3Plh9h>hU(U`Wa{E zVoTShE67@IBMy88t!L5xvx%mw=)f)h)+J39^sbtKnqHLSy+lA za7vXPt&{=mwA4HD3yz*lM$dIk5W2@;@?zNodZfJ9$xMcl$1)Bo7{@RF%9Cwji(RVi zkA+#QQh7+Z!yH}v%HxSt(R96YITNBM*^sNp`|W;@DL)+WC%9s$pOI?9)6X*_!gR=m zBAmm+vYV-8#~kfn8J(|c3DjV{O9pOKPwylkuMtL1ue17hw|ls>fjuE`uFua(Suh5$I;9orND)nFMqFMPPlh0>{AC}k{(h5lChi@f zjE@Zzc4DNi(?dn6ikkZ+#HO%Z8>*$XYKdb|72F9E#pHYCS_OxzLm)SDzg>AT!Rfpr zGSxU>--yijifN10eo)#?2(;EK<`RQT4?yXb3F9;ZS>lSU7H#h_o?Cf5GN5kkVJxhq zQTB}Bzem8Z`)s%9aG7h|?z8F-NtUO}l)*a3y8>5JN%?!5_#`%T)L7~hXffdJS0I<$ zH*A!OV9v4+J?gJp=7eu4en9int6gU%t4IrweC$1~{ zucbK-u%0V&BLZf6-LrH0>0cfMzxs?(K7AI^+LeJrH6G4HC1mFzpK82D;cMy7^7YI zRYPIt@4>F+x~=qUgda{+(umu3HKe8WdJW^<(s4`gy~$;C#<*Dxwj83rpPl*xS+zak zNvg|Y6nwR|yP(jq@wRHy;kHwySw{in=!D@ z^meBE=Fzs5^=1KM(o&pYNFq$2Z!f>ZZZLYBA%6PF1D$wE`o$F+9(m02b`DI4elNO| z?&c;004)wyVz8MyqYx7uXNB8e(jB$q52>+mSHc&A&_fz^O-Fk}y3& zFVGU{gpd+NS3eWM2)$bvHiNE{J^SzP zkIu@|-yWojJwxLCmOFhz-fRFfYLe>0^8fjPIXSE3*`S@9u#R^+yuL6|qI{A#>xaPfAaoFbab#ehee{gn)Be~p8G zGsw^$@geKNEnPY?bp+Bb)S@?*dV1%##`d&^X@R(lYW~#GUx$IH4)DZgjrSP9muYrW zP%kaOnVOy0?YyUr0XjFaG(g4qgJ-)P;pW;pM$rChKz>iP+X%qz#93LPY1J_-s zxh5Pu+zxbX)4G}&?%g~2sep1bFs~DTW|`Vs3JeI*Z<%v^WB^y^d>%j<-Kz79Ajg-y z*_dV9*V+pM{Liy}F?EFWP=w^D#nJlTXB*QvUOcM${T=>IzC%FazORaXTty$NS&gef ze>tEM4mWsyru25;1v(372nc-8Jc62{!}FA9se!itnaS-fc(M_U4htFpGaB9#Y zN$q7kN;A?Ubz*($Vkzqe<%1ekr^)$|#^sGY?&s6En9H+_AL zaZkuAie}lYC$rssgL*X$rV;NZF z<`7kY2G1}~y&u`*F}zzp%v~4li9jXBN-{~?c=wt($xnT7sf{@Uf3Z*y!QK} zvs?_g1Nzy@2voDJv;35een2~*wOMhuwHpcdn!*LWR`JhUi#t$%>7E|GHa`Wm8~49p zpfEIO(<`NQum3~kNwOQW?7yVS75orKnuVKuY~Hi!qsm%teQ_MWG_tYky){<{@TIqX z|C@IkLBn*Y?-(i^R(cPk4-#fs{^6IE`=06RSDsd`u!T9(b-RFs52u^^B2mw2z?}$s ztOVTU)?Q|7l?3&ds1y&I#s!Jz(otW1y^E=qA_|_l#n7HoeP+4!NwgCmR#(#Qw-RB!tu>Cxn)rq=p7b!@QRNp_X@|}(x{15LowRcqC;6`?^5~|}-znAMW z=MUf$K6@V&0>|0AJdjL(4Iu7Q(FKv8CH#+IclvQVjvYW3s=E#Ntw2MF85bg@UQwY( z>on3B%k?ZRWcp-sKy$m(xxXeZNY~r=NDuLS7}n`%$k>eb*r_8vj^+C9=GbZxPjfjc zurETBvNy-g{nuV5fezZ!5#uFz}n4YN8U6M5>uIXZR1dEf=6b9D@E z#=JcDQwMWm=?otq)9hO9qd&%uotx6h188g4Eo5$| z&g%WT%YRzc$as8THsmXVV^_A(5i!Rcl>g)4kNf?c_4xHW=USEi0H4A$V1TbUz_G*P zDDW?1eTwS2e*)Z>w^e>dfd0X16{XfV?K$tLZC)acI$nPiVbGg_lLAlMrA^sR=#PDS zT&52C{79?A>b#l{Ex*9&2)OjuzYg^SG^`nY-Cf-z+0B#vcU*N(t zhr2s&Je^Raw1-W(%pCa;u)5&K<*B2`oyo@oG+oX%bax8<5%Sd17Nm%|s7dP6oZ?Gk z{4w-%8K4MhzI+&sF#POrrNKW%9d2q4C!vutVBK%{s8%i5t@4nGDw>BZ^3xAvKu2N5 z54%wZJ=a;)&(p_q1rz@$Mcj~7fqbE_@HjsoEwCko>zuuW+INE?D;oCR*eSiD!Sw+u zhUPa9oF#sN>kuE)-Ne~5b{ZWQ@LuCywbA+p=3NX3nklm`4;#E}-t9f5iTM@Z9|m}r z9^DrFWT^j#YY18pjuw6koUq%Kzs-BNS~-xbnhc#Yz1Pcl z@3qLGJ(agN7_Pu-w*SPT-;4KsN%5A{4dDbG+iPE@91Xb(JRe4r;=Qe;_gCLLX`!0q zdk;bmzh{_w%y1+?i$Ij#~rWi%@Llw*<6=edW>?9sXP_BM0LN8Pc3d3kIv zs?D@p`0K1){9ADx+!7!7%)o~e;gWm%Y-2A}%&j>eaLgKT8lyPx2b{5hcUjj%nAS(^ zmj7ZvGw-Tsdi0NHfJB-PV<-&osi@9J|=E31`$ZnjCTe*M{;wTtec zJp1V(ezrvIg&fPWjri%i*{%}M_Ms|VH$&!V(F6>v->dZ=OL=^K{(s-GLdxsua8Gzx zW7htZNK@~>Q-22EJsH&Jy}VkpN17jY9&jO}dzWwQa6as7I?}1DJ9z!x>a@@{gev}v z4V`oJ$_IRr)1FROL$pV>rhfFB=JgiRI-LIc)nY`CKJwBf z?hR`SJ-sbG510*91B{I01y|;FD<(#_j6^RLUui&p_@*XE`Ia}mq*eP!bfk_YLq8MB z!iU3It$FZyb(o;mAzuYO^Iyj5KBNrGJ~2zsc!x2qBpZ|z>1s)w>&-BgBhhB7J zbGObL+9kCY0&g2*9fwzMo_VXhE96y1_@7>W@UL({l4o;4!W(8Jt?^(-eZ-F!!AVbI z7pCJK^KJh6NM0SyfX!d~lK~cNd_4ZC$!1sQH`TsED|g5C$T9RUN_EJ;cn|N%u+lW+ zHnchJDF2iV`Q2qF9MaX#6lrL+q7O=Cx}q3qWheY^&m{Xs$bC`gc#E z-7fLkq|a?JiA+v>@#u6fsk+uZA)mYLeP>ORCx6mtY6D~T&|S6cYC-?C7VC(v+`4x; z3A>$yjgdd@_q<7C+aKn>8_BU^!D;o#cuLTqopVs_$FY=WRW`>Kl z<6uoKi>I9V@}_N)T+c~uO?uh?6sos|qmn*jfpImwyzG)C6UJUd7vFunHv z$=Tm_pP~(mLnZyT+~L|VVWn=r9XGkrZ}Y_Z>wTVgSnLIzNhzDXli%?~ervwqJtccw zw;nW;Vheikg^g5X$o3i}*7Xk3Bi=9wtjUzN3YwVEMq2Omw_(Vx@S~~RA@xZBh0}

    G9JxB{eQ=ol*ct!~^v5>J*zl+?YlEY$M_BhfaQA2E~`9_hwoW;9a4 z&5nplg?CYEMPz<41{5ydRbI}D4^+jli+3d~XQrpSN$nQYg z-kSsrOT$u}4Ty?6;=<83K|ow>+^Dn-VihfI(`qgHv3!5|{gFR%&OIl8+?@A#zMjv= z6OJc6=4O+1M{vPf*9ICWlgjhxwM~EY z*YCHozunEsbnA_azb7fU57HM3jfm~COeTSynAz#iSPx;hgR&?i1z$tv8s-cYXd9Ei@ggr>}j* z-9vYMB^ahHPVFArwp4Vko*wz=hxx0Tx8Fln&cCY#83FpYH<*Hrmw!r+oz~W=H}XPm zZhbx6isaIFI3zM=V{xE*O_};iedB_=8;?*UdIbdC;~4uL|1K&PCVuq#(IY-xs^aoP z2z7VGDH+liq9>xdj*X&#j8QBVwIJVlcW3*5GBJ-bdLtfT(Mp}lKHQ5CsgYBwQONF+ zRxOlxz!xSkyI@BG0AiVWZ^b6`>FG_SeHE^6S&Z<^J>P<2;t~tMw#NIF)gtOz>4A$P zYZUt~7w2XHi!M#URyL4id5;*kN#V=wHG2e@MZ`+fc7dx0auHJg-DSbnyS^q0gfVsn zsMu2E_813iE+aHC)0e0Tz|#^w@bHiUbNk za=IC_mq5|ZBG%fr7fa|*6yG?>L1bqVw32uUJ}Fw_ykT}*A0eTL;oU2uiUrB>tQ(;t z2EPOL5@Y(zLlp-o+!e1)w{6ty@H`D^r;79};l zR#=``h#+Cp*v@~{j4r=w1!0dSA$+J%6gh<#JWhA9tI?^#xnx3C!wKiHc7!7U0~Skf zs3RY%1B4aSO?@iPBkqm}iJ1IjOk3CjUi5JkO<8JHb~AMsSFEGNEtbXD=9x~bB&+vF zO4IUcP}MiROV;lSc}o|)90>PLXxq{{+iXkmLEI&M$ojD|@&j2y+&yd9OB&ld!>SW+ z)@5hO2hVxp0^N07WM2VH$U3mt31x^5Iz#a`j!DUwH)Vu1CjT{`=&; zSnMg~rgqG*oMlTfQl5fu#ha?(Z}rhah0^v9m0RQ~@ZP-C%C9ANR+|%kRr@zWcH>)o z4J#Z72mUjr{-(qK{&wVef&CzT_uncHqIlc6x`|HtE7jh$p9;gPq54GI?!BwKwKfP&-*yAUV>^vCDL0Wab<6`Ng zQ=4J-Z5S?m&AiD<`Jjv36y$a(6!#zXx$6|;gwy!V)I`1au}#(^;Sb$o>!13@D|FDu zUZ$1NV6c5SgN5ech=|i)oI_(r-1@q^+yE_kq4xEHq{hyrQZeO;c^y;rF0FRHl6-sb zUe7sA^Nocu#Ev?>Y}ExgyyZ`#^Tn)JtnFX&?t_@;e}<=MLtCw6 zj}XfnccN+4&NeyrA66jYMouM>ory&k##o))c50Mug-Mx2s{(kW2Su)jLzP-cF2fG?$Dx*okQTt z4@4-iqC5Ow z);FIb(dG3H3&r+1!wv9 z@Em1yV`qIc{!K4Xbz#4_qoD{D<#>qiR4zNvTCmK#YNN!+83o;@B+3KXjr#1gXD#f$ z6_>SbN03#wGk@vqTEz}t!3oD~`%v2gjKJwCh|XgAtRkzst=avHJo3sTOMkC4p!W;0 zM?2}G#tBgK;uhvP*wUGk!ntc-Kx}^FKeA;R*?3C)`a|PCzc+xy=D)+;iBlxvOzW}} zWFkAg48ya=TkRH5Tywzp@##mT9{$Y4lTK7>_zpv#EB#CBQZY93>|STis*BS=5svO& zMG5wyunMgG0!+K>#jHLQ*SmC9!`P+ymKc#|>)%aAExXoAjhtp;ZDq=qQi|mXm(o!*f@!@^8XdGfOGCQp2MW(xfM&k)dn$dy7el-SzgnICT2@EnV9%!j zD^oE&(|EMNw5jbT;Z@$ecPSNhJ6CYyfqSF263pq{eP(x7P%|ixIa@2tM8}IhBeP%Uj8>fcT`vfBq_B$Xkob@2~KEZd=}w zIi%(N1?)SL>B=Q)M2~atVy{yrpxeF6hJ>lOhWoO+o!9Y$&+P3GF;$pCJy2+;aH1EC zY_1)0%0kr9--#pdVFcmL6+4$@SUXwbM(;mc`d_8}Hsrwi&xvXvtLh&6c^!RH@`V11 zouRquI+dX2#4!$tYs(WkqedUDij$-C8O0cxJCsoN_X9tz)Ml1cHqg~T%6i%29Nd16 zapi#(O3lzjgTpUBfuBoToUcENeb#byG12WU%Pt3G3GFuu@pc9I`;LgD4bSFoSGy$P zw_Sgn@T*~2NKev(LAqe2{2|=O(0N`6c z01BN1cM!;U2WJPm9nHtZ(Z$=zG0?`x&p9yM!TN>a!1Lw!aeRDS1arK6ynO<_1Nc7t zpg><T7Z@&a=EwPk$MU1%_|Yl+xkbZUNd zeBQi_g6L&A^E1;!Q;Ozi7DeZ*o}XSCR9qS<-yF5RGPEGuf1dRx=O()aXQV_Xter23 zN{AOOUJ{X35|O}%`Zs(Mq?+U0f2)D0=B=7M@nPD4j-$DaJ1yV5(3W~-Zu znwpE%O=V5ZWp(N`9ZmA4_TuLDvWDGj+B?eI_m_8dt=Xa3(%HPZxudGRqoQMf#r^{o zU0pT3`^t{&t3KIYvA4H$fA1FUu?%egT?UBwON6)`{sy}~asP)X~<&itbe|Y}Gy*uW~ z@!@ZWiL(HO$gVCji$I9!%tsl8jS#`jUKDcQ46^6*V%SqJQBjLa#<--=I2ohjv4d^Q zqF+Zkf&;TE8+y1#2n7YwLA<{?dM0ni;Fh}Yb^7T&0W%XaN!=;|8p>r-t5X|+{N z-aW~ncNgFx@*pvMiY=yprt^*^!+E=EBz1=q{!l3w#idY0^;WpQ;+-hzg~{chP! zy8jhEAiEs3F+BeG=kevg1!i%!e*g8z|J4ytZ@|+e%A-!)99whV$-%hHC62jgU3{tZ z=~ox3Z~)o=C&NVIJb?^Mq7)8yX4#NB24f@G8;9*8#qB%V_4uN7ptEwYu+%P1WcfmYJ4HjS!Wot8($l5I58%qmT3`{nb zN9qj-|3JfiFv!#g?pSkG_hy${=)@XUJg_fu*Vb`uV)xwUVhkT($cc5U@;1g$B^_KN z;Cw+k2;&(tC7o3ujll=WtQJY`S z^EHJdOmChJz|o_*M&#uR5uRoDf8r73@ar1Ph9U6(6~`VwcLR1P5Ev19HFprV&;Vd{ z3$_VPBT?0WlIZr~g#qVQeol;}H48w3he9Dns9l`OTgCHDB7-R3s- z3j4$=yKw6(o3(M3S%w`$ayHhvu0%z#_N#Lg*tlD>J{B)FbI%Zj;oV_2++tMqBcWJ( zl3-hGurjt+YK>rKtY{DZWx7a-b2P{>*zGz)WxWAe+F_>x#k&VCBE5ykDVB+C0)-5* z!Oeyb^44sZ#8hy)Jyd)tIS=xVpEg^G6?dCPe`|-H*{u3U(AHD!bxHz%XE;7H%3$RVHAa7)a#XK6WLL8VR~$s z(<>srOgGZgWhzIiKaMW51#2(%*z1}XTXT9zS!~c=*xD%6?1pOp{hp~N=UukhKQPn@ z+UjHwUa)tyFQ8L2$ z+PiS8^*=w`x)YF2xKLzjCpS@jkd;c2JIllD8DjM}i|-K}#;gtNBFs%bu`jDyj2CcE zFs0lKf(5;0wKj!V?rvcS97(ka3x8RR*7^WqLY}AJx9OEA8aV*xaB0(IuLiY5Zz4ymIxGPuLF$>$bAY zXo|x}&#TX{RxQUmHUV|f_sJ=Q5fsib5Bsec z!>z6T)#d&WlIGowZ{S0X8vw#W`)=aH#W57eD@r@A2i{*UcC!FGBKaN!zO3y)4&3K= zsn5nQ7$UOAlbki@ApGdYv+_quJDmt#H1GpBox-vw=ImT!qjcOXg2=Tfd7|<8?6ZaX zI>5r0as5O!)VO%07J+;1Jj1f^)x*Y__r$u&y51RZLnSyX3Q5kH@Y z^AAPrti58FmvrAm|LQ`KD_M(>7FxiBRU$Y~hB#*FR4>~8q!SzsK+T?Xrp^PQM7Vwk zUJl*GWEELIxe?y2SxqnYs5VMl z2n!Y8Wy}XI{dZO5onX538C$aL&U53L#cOeMYPmL9QwGXH0p#?-aD1g`$Q3$=cWO4n zd0p0S5(6=8srd(IN27D-5scFXkoUJ2;Au;wHvD&oxucH6*&kl3Y>;&uUu`@6pWN^# zLEr6KU;!P*tr@f2+clsqcE9ZoMid&B8zWx4Bq!{_e+-nBd^9{xPRKvJ3s~z#6n|fE z?CPQJ3&l$}h*r$EU@t=tUOM-y;p2B(LMMLO#bk10y^|RYi{u`#EYOMc?kPYqd<(Hh zQS#3lsm>aRKH(VO96Z&KJtD$+NsSXLhmZ1?d$2=Kh#kmgrK3?Y++15seWyEO*Loa3 zSAqo`ym8dIHO6kQ6~d@Q5ss;mxUI{9t+NC!xJRt#OO40l<^5!hACF7Tic%RzYx45VU4MY73cEi*60sJbVbq^;8c%{XRflZU$)5 z^yFoW=+c6{IQ%TfjL$5ZV>%r9_G6`P5)3tiE4YZ8A~V8bjlJV$?FnmS_$^jwo*tJa z!$n!ZG8wfDh;q=wX9le3On8YFOhu8!rkWlIaug!Y!{LCFnvQ6=!L>nc-Yeb7QU}~(=VVDsrX$LH@zW-jM!?hMk&HG-&bAu2TYQI zHd^GcmNdwqO-SjOA34aIe3)ycjX*X}z%5*;w;l`-)$kcBvbB;F(=3TDUjBC2>&zbJ zA`u~ku{{u~iICNvWG9?`;Imf?OUyn`??ae^vJb#@F_2BKpepJ7ya!&P%H{qhaHBHe z!ew8FOUtu6oWEjicYCmvCyR#g36A1;4CoA!wPx>4{Ko;mQj09IfIhOg6*`HN1*#VC zcm`)s01;%snn6hLv2C-K5G;Ua-+N|Dku06;Ph9NuMM;fP7Wk6$ahczm2Yz3@^b68O z^oBtje5zjAM%-Biyorr)tqLlTEe)!w!-X#iv;f~e)Hn!&gaB?R3z|&8ZH$0$s)~7< zbn8K(=K)Zsr`=HyrxS1%li-~hKYJd>)xndquL)&(sL2FxG2@!`cu+wpmbu*eoBIYI z{N_Dqhk~A3uzM18c@rEfTIl~zM(I?Przrgg3AmP#AJd+eug5z`iLH9v7H#8w3&42@ z;%}s9#yU`SnOxw&r}O-dH2Iq%K^qSehk<`q5Mo4dp#@rQ27Mb~h0JOW#LEOQn}Mse zh4N8+3!{k~i}y4yOfNV3ufo_k;A`Uxoo*;op{&-`|uPWs!rb1tjk!7jRITX%4wP?H6A3!=446 zhLCvy*(v?d=)gC##TB`srixNV{p#0q2I0Ap4ngnO5zoRv7wM8{ZB2+wd`ODi_aS>W zc=s4+ho$ru8F-KlhbTaI3uJ5E%?*3E3S2w)z?&JAU;#6_$8WP1=~co`+HHY`>IKIo zjlF*NZbGq6gp-UTpB=k|W~-AP{wfr#Vj;C=Si-H%`#YiWLZ@8Gyudnhg-p<7H!eH9bgSB_Qf*rrwGW7>le>RO=9dzwWK{#wnXXhb#=4B zm4%z)vzNS%OMg>@{n^GoCdMUC5pV7W@ARiVJC6)?iRKCqGprkd2&V-|Pk@57!>Xwp zqC2*@|1;Ce9|~J>1;4k^H}K#U76|G5|GdX+JwB<6`kc$Aa!F4NtYf~AOsb|6AV=VU zvkdw!8T1_syHg~ve;!6{Sddi>ZsNh|gz)EFxJ6Gj>1huQbQXZm*Tdbb<~%!mf@_-d zE>=EhD>xooI`e6@DGXFh;Yt()djY&$wANmSS11SzEgQ1*3%CYQ%YeST2U`KWd%WuJ zTVS&eU$U&^$|R8dYnTVOYSUrRrxToAmBL_bVruh6Vc*2e#zHsKp)(R3%Lnv+*cALY ztK)U#l?W@_5G)5yziC^*Ko{-uJ>KS1CGoQ&CHXIg`hZv{vJ97I0i42C`^JKM%c?d< zHduRdZrIOxQ+>rWxjsrpoy!ZpU#nmz}o4ee$H0n+K@@#|kMY%d_Ax8`Ju-hv?X8D4b;`#cd1#h855djqt z6Ad1fJz$D~P$@_}vjt{vhB_;duPm#o!dV+1_=LTzIXmb>SPceoq1E22N}L%4}f zj9fAbE;Ivy+M`8v2+9XJwWEi`F*m6gSc*!g?erfAWXSCiLIo>x2x1>3I$=NoY22BQ7c^@N+3 z-wxz$Pg$5N!>u>nXE9VQ4=Wr?^8JtG{c9WXSD#4Ycbox4kda-Njzo})s*v$H^Qw2U zxsdzCrq}E_gw$Owbzpck7@+lwS5W`uQeT;kdj&Qpr0^0wvXy~P(j&1wes(%AwFgWb z_Wf&FOy6R!iNvizy{k$}<9H~%+iClvhM3Pc1FNLlgbXBXHOK=9*j}F?10luilQOIO zdc)mLwFw{0;7LA7svt>3|G7mIBUP+@>nj6)cix_Cv0bk8J$V{z`x_p94$YOp)eK@a zmwd>=(-goW?k|mOcuwxAcgMa87>Dj3^Rs{XZJ^=j7GEd{)#AUqmDdJFo6ZCSL`<5L z3&8VNhf^3nF&W)UtkQ^%lq87QoENn2h?9?QnFFKXd}@$mNrkg?=lD=swjXqn)()Xx+jPE&M`H z+Xax^4YW^Y+8r(JEf>%X)1F!YP*?9H$o$Go{H{osrKdgN!n2WRnB%fCEIm zp@G(S%(~*x51D}o0r^iUeX|z69tuY)e1jOXenGkLiAU{y2O-fHhLpXscF1R0!p7_b z;LyKb)2GLe__@AYV_r4n;V-zPrm9@ zENgrA%8g}7elCGUQiK;LON=lRdrUNwp(sL%k2gU@p2CAjjlYbr*%BBaMZVh}xZ)AC z+CFzpF(_dm0Vco>{U=NhKdA9_wuW=Hu$|dY%!NPx1CmkT<(oylv!F2pj0V6jS>Lbg z8h)uT-jOT(U;Y@>x$0ifyf+nvXt4C1Vr{>4oH z&kX+IokV`9E!N|kr9}5x1%alef7a7JX{o;%Y~GpZZ?$wE5dZ_ESp^L=>>DeKjJ6Pd zvk;=pxE94njTtO9MeRA5ekg3s)+hKYb+GDH`Q*_iKj0#&VK zKmZNZ;T`%51}p%Cv~cF74*(M|j!X|vsOEY>h_wSM)r=M1XcowPQ%6JFibtHHJ*&}k z@Pm4y`|S^7Rt8@qNf0hlh`Rd-;knV_s8zutGM(JREef%UsX<9%Wj$#T4Z><7m9t2rH>xxY&cgcfDeOivvpG>Z>uBu$ zqqsSDjs$p%DG84lmV_z*yHQVVj@YLb7{MLR5zf!BcE6-Hmw5>-e?MwC(3g7iQ$xSj z@4)w8rL6CjwYs zAH+GV=e@i-D^+;l?DVk$TxG&Wj~x(=!NB=;jsAc`(yWcosa*_9Bs-7TQ)i=?+R)&h zsWH~s7Yt3s{~2yEM^U{VG3x$YF?4Q`-FgMsz%9sOhaqbXK&(u0V)FEVmgQXrmNB zHuia_0m8qo5vi|8mE1EIE}1;}ZN5eK;3vVvlOT!6h;4G+49HNzlJ{H{GdYKUl9}2l z3ZM(PUY8#xY&Yl$;nfE8ZjVj}wsYj%e(EL5P4};6{z`ON|NPlYzpHLJMLA?yju$fb zQNz)_iYG=dm(tyZZ(MOU7W80E9Mb?(MJx|!^#dna%a*u?;(ZAWv$9SfL#8Nw!x%&Lt*j4EuV8Q{_jpvYjJNp7*uzyvhA~ zr4LQC6J-iGGXRs}Y3ZR5?Dv;Jq5ZQ4gNhL|)}w@j2PSvo6Q%&zrgBt=2>8_o9rc7; zgx79?B)pN}J8FjO+;j$v6wE7J))38&g>Fd3FeHJa$T~OtDs(z{vCAm5+cv=wLpUb{ z-D;Niz}0h|y0#$Ywb1iN+P2WdDOqN@kmnCjObE-nsY=$^xus5 z4s*h-xI2Y2x6}s|_2(vU)Z`T!X=%O8Iuf_2&v=coX~uZXaaC701)323SSR{h$E|}l z^-WNKcEQlp@qSFs@Cnm_A#F%di|~I_KO1>r$l#cNM}c_T1qTZEdN= zcXGjrMA%*dP-X+T>Og)U&aMkJU*7|t9WX$VEUo2%-{D!gAwC$uXa#y`wpT}-J7Ar2 zpSv3Wh>Lr04*J@R5QM#Ea?)tJYllJYscnc^%`*MFJ`}V^LWrIW43}&{?Rw2VD?b@Q zt^uVt$Ax-Y4CFj6Mq<|?C!ZOfU5Au|GBVq(jmLC#l3=E$6|WuEt25*1hio@c~+ zTZZr68;)3Wtsj!fd26N+@rL zKr5v386~b010!Djt-DI3;%sSKy^W2Pka=nXj7IU^kfFXjJcd1CU0T9tMG2{kWe8(T z#r0%>q+)H>tJ5g6={gB{pv7OQH1o$bh*yjmXwLkd?=D;| ziQHswtx86AANNBBr$CAYg|ehc4yz5=!BjJtH8yD9Z?#UTJ!sBSka)@}!JuM%7K)K& zCY5U{KyWu3o=GpM!mh6F?!{32z{iVskJtM&M0O%$W(`HlCBK-7g&k|a#le609SXZ; zzr+h>DY{v0W1vr}v|;W{Bab;CLRLLIkr{+9@p|YtFY0`2k?(ETTgzcBF!NJ_>Q|Q> zc6k<;eW_!vfi`#tu4x0Vx!kQ{-_Q{(qu}hdG|Qy(U%5+by0nmsmPfE|*nc-_VdpUe zd9GCbTcbh!dm>|*g01X9>k@`^5;3sp(JG0N9A8y7I+R+O0hchK7rW`bg8FDc^B5vlwRNge`ZGwV&;QLI?m6C3!$I`^G1chu~ zoCmp;cVl@t1DIo#F8?#fESzDIqc!u{CWQzet!;@LF9dc!7$hsr7{E3|70q$Z_6&TP zhzl>D7^Ug4mZ0%fkXsjOU!HWH!S`a>f~4hZ{Z$dCK;~LAC+F+XLdv2ZM$u2B&QngH zQ@LFj|4K)uL0aEdY5?uzhlu?qG@@ey=b$Ue^0k1$(h-;rsHnvT*iCl9X1N8R)f;wv z-l+EXIAU0gW(Dp;=GEMug!P+J@8#Tb zo-D8jLnSli9N`kTGlIrwGj{fbV?FO~{H_-zp0+(xiQYEaR`)L`X|Gk=3^yq~b7Z(* zj2J~GgKcue%mdbgC9gn-SFm~dX6lHC~&UXfTe*=s03;B4eqX;P&X!<+E6Qs(y@Ia4kLX0QS zz7ZE%7eb8p*8Uz;xjMw}88naeH2w;INjB3uoLK8s$CF@l>vL52p4;^iHZp=x0Mi91 zP6_}5bhWl#qOFe;!PX9Ce-!VmM|ooeIkD~=1F+kGg^ht>L>yhKESIQAdS!{^IK6Kn zRicXFshM&@q5<^0ytD1Dp%K7ARBI9TqK2MAcvawwh&3H z$rPr@X@GGsc5)Vs7`sI11KO%D&ao(|)?Aqmd>Pf0cHbH>CGkC#!Y(jS54}DI(YnC- z2F%sKINuG}m}BG*8^;9=1n%*M9Gh&fhR1r2mDq|`sFLOA;8r9_bbu$Z<12~F3R{L8 zv2$bKTnXXpYQon#gvGyJN5=zDDp=>U3Mmero99YrJ+#I~L@TQj9%65)b&r zQR;6WIQSND(wy+=I{8cCjx^$-NyqxaK0*xx_mv}l&mCmh41=P;9J;VuCosEe)IkzV ztH!vNog69VFscmCIp};98|_D(d4R)!b`B}ntrTQSu{k`@srBrCpJ-HtCw12GT8))U zDgf9*KtVi2Kc{^>8I0W5PIIr#jz&^vcEZx=u)iX$k*k^q8io#_-xncZIUXfAzl5VNiz$3rM$h2FZx9spg0C|h*hnrfO;deiM)i+zagjr?|N zSy8CghSRqYtZ*?r^D{a?j@l4(E@Up+B66ZRVm5k| zZmk@$lGUr#aXHOYGsb6tVFr{xhAFHDPbn}PIEIe};ICdD;Eh2xT2wxUmxxdseKTKy zlrb736wP1p5J-;_g^<-`U^5oOHDh+80QwwN`|rccRr^OZXQp~6I%ih(j{WD9iTZv5 z3dZI3>K(zPagfrdxs^1p=Worp<4Had4tlk3{}>j6+PTRACml~~-qTsAV)2^!)tJz6 z*W3=QZwl}-0Q@}APY`p@fL^lz`};gw&Hm^3-cAO9%5$v`{k}~DOa+48LWCsoQmNE#>b9yO0cnP#jfe8lcYcvT<8Fht@bv1of-{JhtiYxn5GAP#l+5|XT z>)XzwWW#B)1qyV$Nj9T&9TqqWyU1^*^;uP66*T1@0J1_paovu8BQ)a$va_TU?tc#K zy8N5gp65BYOO>SuK?QWNNa>@4%PO(b187uk=08`?kWEn?1*3oRuB2ILt|A>#uch2V zxL3Xb!8Xn(%;DUqaUbi<`&KEJpsqaCc0(2PcGpLsjjc!B{UMt!)a~6FLDw$l8H~uZ zJyhPxF@R3;-M^`BUpxQ=miRLM7k0K6-BJ`YVzJxVweYHlLl@aS9c{EhQ>fMnuMK+Q z0Sdq(X!~%oW(KJH3|JvSs~b@_w3#;r+5(u5GnS}R68CqktN>c7WBLnIueJt^5WVgg zI?gr%hy>;7p@MdlXI5^SZj2+MOpQ7ojrk@?6{G{JHvw@dK17OsQ;kN+)e#!Cd_YOj zD|ut+-fO7aC=ep^`a%Yo12{)HYQ1Mrv{;$#zJ^b}Os=Xf1ZBvc(=42mF@3IvMhdWd zhn%gsQw?^# zs_}{lapDaz&HnU0z_FS#7~k|UdjID4O|yL`UVObB^)f@YW+k2r&Z~wquk357*3aod zSyL%3Gs#3#{rWMWN_fDvdLB~%_UnDya;Y!hA3patuwrA&bwRrKOW%=dy(~l+{x*dL z+`D4hcct1%*$>bZsD}hkEIeo~9I3$SSGD3RS`$iB>r*AFrG0?CUP+rm=N|$5cuE0J zHNORO=R$U9`{N`KKb7I4!Mp{T-re+nm|D*OowYIdhl!-Jp`8}X&RI-3T+}|tgf8ma zv1LF>vOK%|8C5WFf?Mihh4OPULvt!|TKf1@6R`gp-MJg&7Al`)0PZ{`c?P9;yd*c6 ziB7GBL6E(ndv;(J3}n>RNY!P1pwr)N(Yjs#k{He+aPl@ro}!YEgD*YW*#_+MR^2xl z(L)s%Sr+V>>JCu~;KmN`(#0@&*dh^ZGHcR|ltovYhGtOBZ=8D(onIY1dOr6uptB9i z;lTm`#|YRhN1eR(38mJkFc72#&X0lD^c@|VWA2mrA@4k|HniY%`_Z~ZPF>Wi+&uB{_s7F1iP>0>ji1IVesJfL*`B0u&-z zH@(&C76Kp-#R*Zo6t!9yH`$|j6958el>!EG&>v)}$7dSfWjYA}o77qCMEd&91k&=m zTJ)OK>%vTg8;tSMMw*m8M~lr>;VIIePw}QA1~ki6fE!U2Y*9MuQ3oAL=k;F^XmV{y zjs`5524+TO(@^m2Xd{1sO8p<2v7hpYi@D9fw${PIEIf41M*9Q#dX5cW2?Q{}1dSjq z`n&Z3*wZi|;^pyE1|Z@Nht?OTCW7ilXnUh|)1_R@^y-jx?0eVvqhR0#x8b>f_vtX2 zr-8`|7IE*%8X2MU{f_MnT#aTVS}}j^UA2z_r7@g^=TI6q(nF6q&7fpC-?OxQTbT>Yy6}2FCLP)shtYPQ@o~g``s?8$kyX#kS*N9z~&9n(>6Bs$W1_GjMC z1E(;73{4>7@KaA_c}jO4-*LBj)~S6)kG-(BwFDU&OVj6NKeJVw*Yr-I`?sP^@h%|`#z|xQ6T0WEWUgUW10mJ~-PX$z0nC`f zuOCHS(Ugc8L0+hkk)vu$wrkO>vC(|L{V23f4wTKn+lR@loXcJK*c8NAP@Z$f5pO=-wA#XFZW*9U}8Eud&s+Za@_tg-ov@(axvqctj*D!f5&-Z2$kCf3w z)eh$^aZ)ukm8P@3Oj zdAC;P6g{5TY8|O7oy_u?lNecnB1xcg9|&sx9ga+L1%o4uJMB;)79eaRA`Tl*6W^LY z8ySp^$}}Mm=vW^uN8-9bd5O~DOq6CQ_cw ze7=?0JpSosf>e229eW9WY0i{(IWtSW+m%~^%4(_-b6rhw2+=wwHq5nNx^(IZ#u zXBvrC!O{U+Gg^-)VtG!THX0o{Ow(U881smF=E#$>$)m&OBeMxzRsU^ES#lNCVPN>! zj#X(}lJxZqk?uL~U6AHB@h8BBd-#?P2$dma)%A9k*Fo;+`X4r}^F_<;ftT~vuox(h zy9SY{^Cz7*TJgC86d1MY$p#iz#X3yfsd;cp-t_yH0}DQB)-P(xuRYE{k0tLvKocMh zq!z1u+qUW-*4Vf_->HzU@Z| zqAa27NsS$xn0?Er)q!VFr>IUxwRLkuK_Xg2LXWyzYgBPbmNwDXFdKAKP1w)abu0Gs~VVWPbuTX#>tV3oUb;`gqAx- z)ttzr+IFtVJo@cCN1?1Ipq=7u&#zRNu#kGPyL_}yHi>I-2rD-!-s;}to$d1UdO(Nb z*zxH)#l+RBo%Rtq7Zm3_(!bp@eQLlei&mAfc{%~}hrzmuw2%khS8>#xO>AerY{2rr zs|tsCJ->)rd$u$}5s{pkLq&CJ&bx~1W09UmcM!aa1sk!RuV_!t zpC4hEsq#wOTC24dtrvxywyIETk5nQv;_VMIH#T z1tNfMU3&m~^XC8{bM*Xgr+0}4N1Uw%Fe2aro?+Mn7hh0$ZR1Ydz_C8)jOp$O$v(fV z`_<`8zSwNAck7!G5otn!kr3L<7|vqXno%2fkrpR)2Ygfq*rR&YQs-Jz*6dD^@nU+) z%DEOJIr#ox!!)O@+R6Nt3A)wb>`XKNpjvowCkT?-U2j2S4xrHu0VRhxS$|(Qj#8SEd+{;g9B7IGM{LGz8G0Bp}??~%0;cK@R_a0oBSO{4ff9q&z27yb}i5VbYH1*^E8$? z6sYzvajAi)x)HJoP2n<_A*<2q&x5G_B^^u{Kv>HqFp)xATw?5$HN2f}N|4xF=O0M>Ihz zMm6y59(BebE|@-$!!}sYw4aJ$rUr33<>*iQO96u+AV7UX>ZWbf{1bnN^E(crvg z)`K67IP@$)00*iktgt~07Ak}ksSdKURwH`~obZqCY$)skQ*7%89HK|?b_~%Uc@Fd0 zIfmKoXGlug+R1}U)DxS2`a~^%bOWJ~2snaex$~y9T~%XTT!6^9x!_?WrRota!m@-s z`bO!X<67l@5x9xqa@r$tLGWM)BD4S-VkQ>;`)mnQ$Lf0jxoKL#Q7Mvqkz0(Ib zG5so%AV$}3OP@9y>@NtEJBrP)g8+zM+l%FYC_;jW%9$ZqBS+AK%qQC>|9@lmKsSpj zQZ3?IJ#3r_9ILeHK(PUHdqxR%pK6^`^Z-ccD|C4W+_3x7x+_THNuC`kh>h|c=_P&f za2+2=n3oJ7_FI7rrZlD5*N{)wd$`=^iU>TdrrhA=@=S+<+6eR^Q3hV$x6Ffit}fqI zmW@<*0xSS{d5-PH|0E*FWQa|(jM^$!J8cG(1glp4b9qwksMMawW&}H(I=V9f=xvb! z0cXN*;Z1kEgXA8(kKRq+zv@kQOhe(&0e?b_tS;1q;q1&}3Wo~`QY#_;uxRI87Spvm z4&gE|=R@*{{!v3yyz2IkXY9Q$3hw!(I$^}o>#aw`>>}353(|e8*1b7Fkf-&Y`uTBb zzsM|Kltha698jGGjwd-jg9Z#S7_C}Y@L~1>YUjzUV_J3X>;W+$(~PCPLZJhcSV90Q zba7lCva2u^_jiJ-{ReYdsMbh{F<~1U?yrCMP6XO~82DTEN7ThzM!LPjf?>Vy=GnAe zQ@#6Om((iK_JR8FxK_>0?u)fwazQ$tT@MX&Pk#@#f>)(ie|80MfpP$nJpf7ZBRiHo@8a$15&DP&+Q zKkH~N)Uq`3PHFM~oo}!=a50-luU3EkGAMtDy|&M>$)BQ2&eQG)wfgD-A_I#snK8Fh z2BMD%rQy@?^l2Go6=!zuQlW*mUI;Mn=ftoMZDUt7oPzTmyS`c6Z&!8BJ2B)ogfnp= zUI|87m81IRvc6d|EgEFVVm*3vnq(1(50`6jd8=>%hzOs>l}vb|^&bD47!Sr0+|B?f ze7!mwWT0@Z;$+rD@<^q+%TX6C;C=bYaO9)IQJmj!Rd$>A#| z&g(hk_p;AgcRZ6+Y`O2AsDJjzq=$ZMpc=0N<-whLD67VKwMkcT{niKmG6yhs$%7de zyB<7qos)|!1klMwpemXdUqDHls;2D%rsvmuX#j7KoZG3l=jS5p|LneZ{m6qarKiXK ztrVWc_9XRxrLDuvbU7&H&ONxhSc<9e1;bS()SNY7EmB1OsAa4&5GZ9-Xxly}m2tgy zhpr{QsZT9Ur{)o|Ij83&PAx|ZsC3OJA2% ziH0grbN8)r;GbC_nY!msTfP3uo+54cdhc7KNpn}*NkeVLswdC9HW7>TfQ!H=C)#}P zkRH{<0-y|0)s}EGTxj-u7U}W{$;j z1N=OSZ@2ntW05E*fa-8h>coZ#can@uH@ZDjV_|AeIvF;ZR`9?YV_Rzyu)HfCZvASG zo=naK7+FbcKM(>l;Av<>t!@4qle=Zd4GSABAo(;*!5W|V=;^lFHhv=}4o4FT2&RUD zG;RJ$0<0qdfUtbK18&Ncx%DOX_MVfF`A#-$sh;SWvL9QZL4e3%wO5{6T+k6*TeiR= zO(g(nff^=sb4oJie3^N2Z68qW>XIbK>WmZVb!U>2Oy)&|~uGybXFmV%Y`3F{b>MXs$7Jz9m zV#=<)1)a9!PR(G$&g0^n-*zwaFH1|md!P)@uxnFm$vIiv8NWYAjAc;Ioj=3Sj3n}Y zeJ9e`#nA8IX^n0a$&;(mPD0q3(ENzEyyam|Be1HU%4zNVul=FxjZ};zhGxjzrvV-H zjMdS**N-weSpb+$L5RBGA;VP;rpG_pTKw4K85lJhH$xRSg96jDXJx*r;1rvRqCin6 z=p%1?o<}ZBarbWq);0pJboeQlCJ~dP{YOn3ULy4_Ya(}CBH3nO<(65Te&f4`w|J^O>=I&U7T^IuctM+6>On>=$+a%f70kE3 zu7sTay{!u_K6{W>07}GnFE?snEg80jlp|%`PD-EycWMEC1SkU3E-h_B6*MKfUD{|_ zy@X_p5_L+G^g7@d1u&lFqdEfXCRZJA+h*-T`>U9Gj;@J{{At2&Zjf6#Db;uQ?L&83 z!M#}&bBM!TN{A26cRAZJID%sqfKvrPyR9E8< ziPtMr-(wT*znq+J-c4aW5w*!t<^!4sov9DnQtu|1DOAs1)ZK$!CqzYUepmO=9fzCc1NM~CC=!Uemy_l8Lk_b{Lb*1Y7)`|bJOaD8>UxHD zd%AmsIW9^b8$iZwuN`yoUpkkg1YOhjz)ipem6~Nx`;H_G%PnhgjpP$`^99LBGCx-B z{k;1XV4e+))NIVi*+jT>ySaM{y6Gac-Yv|`^9d;jD?a2ByiSr=G%!eqYU)m{dC0WT z@u*fUqlga2w6lk4iRnn#3Sjcb*Q!JcAQPK&DNHagD`F+9d-j=MpY0o^^3Y#}{pLl= zKf*rz{|tm*A_^G#L7D_`rACI8Wb|X>zy_{{XkGz${n_2BG8G4ROh<^4Jc6O}!EynY zI^)(#U{2*6MDoBhJx;EHQicFeO`D&15d(RMCFp$6I6*^^TVDK{)A(fL(t6l- z{P8qGUf0q3Z^9Nkl}gv{)>_z|yBUpZH&9+;?LJO!yU(-InN{Z%^0w4(6VNQa#nL{b zD3-5>T%(MFBDJ7cySBBQe!#kQ`aVs?_^A`t?ji)~-7_Z>gCYuw?k~SF*D^@;tXTJ# zc28J*_xjFBE1y)9@1DokKYn4OI0Z&O;`>VK81Q60Y5-+)rYHJ&Tr|KR?PO;Gy-nnd z`F{Z1^&i{Psw7o+^5UyhUCgTWo?61I#dz;@x36xeXKL(pK+Vu~=~O1Mlj6%&4|w3| z1nAYpofDsYb-nR+riMo40Mck^v3|^P%EA_%ZJRfo@80>SxAEZAqjZ|sE~yfH+NJo;AT!M2z-DBow(7=m!b4mV#xB-A3*PKVSJwu6Sznz$A*@`Re!k zgRi@eSAugoVFDYbq}_mM1`w8g`|1}!OuiW#A9??Ew zI&$>0(Jt-DTRQH|=s>LapN0Q>kxYzRYdW4m?+U8Nx{&JleGN$R2zIeaBc>3zHXhnn z=2iuG=$vU1ZU#}7R?x|;rMiy%8*t#k+a-6;t?OivO-{H+BUOv;kZcF#_K@N}Z2Ljk zrocm9P&Bcld(U?O9QnKZ(%+kSSAb52C8|x)N*C@DY~p~IxwO8WU+}(s^YQUNt%O|D zjXtepVESWKR@=TkG!)2fP01|ZL7&11wJU9L7;L}uWY&k0TPtR}yXs^4iU;lmH1NVO zPylFhiII8&jCkH#dGF7HoWH)*7PWW{nX=bct;^iUH{K*R*Ceo`=%{J-bAP6*M4pyK z8QT(s6DQ}4kPbYyb;iU^Ma5xbAZw2(P#kSl&5vdAL?KD7c@4nVIk?EX3N`!8xZ^AOV))+VCP$AUfVVjkWFQ{%t$yUKVNF*y13*c3+t9R$L?`p-JkAf6r)=Mrdh8X3MLBiA_1 zRvcn!yNUC!{g2@tP5Tc65Yyj4L=QxbID=gjd3w%lsZDR@O?yvZp17B>H270xr}2e< zu^y?vWXTB>h7fC_UF428T`!tHhV;J4eUBsjD9OBJ%)A$Xd^#G4a>b>Q4%910TVGDB zCgD8^v9lX6k2*_RspslOnx?;W#98K|G}zSAm?0A4D2IRt5SA>f$W`@pvH~O8I!*30 z+}6cQ*9|hg)B5*RdS`Znv}soBUzw3%Ct35tugLa#7xq6w#q-CgfJ*UvLp|zIV_z#i`ca)2}tQ znL@^NBxhE$1g_v3qD%;0LK)>O>BP1QQHWs}OprXu;R!X6S_@vJ@^?*&prv>gdqQsM zJ}sa&2NR0y-7=}9-n_U)p(%Iqa{!1a(w>6|ywjKO44z=Y`?Nkl0MEC@k^|M^zyQb$ zx+_Nb_fB$6F23n{j-M$VEDat@z$WJ1YAqxUSXBD(^p-$Nd+GTYy}1f&5dG})#ZN3; zlm$qrOHcV$9}?0jh)AH1y!j=g7{-DANQrD0rwAChtE5|Hes2mdCOA{?^+TN@`LTPy z#!Mx9O)yIUXe^L@eL*jgE{sshQcq_7)*&+u(qF{kLAqbgcFZe$lT zTx5H;2#=%JiSv+mTS@WJY71821lykwSR|cCC2hg!y@Jde2lnx11Hma$`h~=d5Iy9T zwm%u`!-^vUf^xGFcMHfcO!YfWW9WsAV;H?v)ksAn%|@92!36)|>VM z@ERCRJ8PiJr}h3!A=~5zrBp1K%0(583#I3%v);#T{8ANJy#1iN=kadVNU32g*G+eQID0m4}d-MO`zvm0$FcrW7a;Xe`C$g+u3d6-fTaKMzy=G|Ivl>%37`fa(K}g z2?9E@{?wD)d{L}O)g_YvE?!z zI{Xt=9I3l5OgF+3n^BN(!jwg|LCeTepVd}V6heytZwWHwv(x-2V1Ybj=gT0=9WrTS>!H?GA9OOEGTth+=(14@Xq)KTUyA57~l1i`iRShs{uDK-Kd{_ZICkVh{2@HS&H zVc6?it6$}BX@3-QoDXW@0FiOm_$7Ll)R`bh4CId>$>w3i(#^NLz&n#6NJ|w2 zEY#HhV_EPBf4p}|WIHSBzLJb)cxEj#EoH&-bzkCWH)^9qcDJRKZO)uUUmI7tv- z6s_sqeJK3)e+p*KVvB*=kZyBeO?9OQ=L?Bts4KGIAn5if%vEr^4UBL?-Yu&S{}{CY z+`9Otr0Z;^5`TOr`eJWxBeZB62`MgINuKcw#k?wNFAKLKWAm!o`R6m8D2>VB)!7!h zT91IMCmf_Bm#Znpd|>kqF1n5YOO=k62|2zuKRQBzjRdP89t}yBtwdj4B^g%%h7-U( z>aieF*hj#daHe(_%V5LP=o9)g#C3v^`jKc_{n77C_+=Y`Z*`s)H3PgBn~-+g#M%%l zmLVrwaR3&R+}E}@a_XrnCn9E0GkSYlR_?pb(v>0Fze2h|mCb2j{8Vg)0&z$Wag{_^ z=g;=X5YkM*K{jQdJXm}ztR0C?gQidb(m=5#G7d#DF5$fQ>!B)=<39l2d&NgcU}eWy zeQ6{^zd18bw6-HMaqJ{27$2x{x;)BZQpChS;AprF*M`!cP$dU-vzyoC8OYOmBFOHU&|9v+l;vvkx_%eun%WZrV9EzMAYR{;_%BR)SpC~^fR|p8BQmAwJkHyh6GZ= z7X8dZ;2d9naP<4MLgvBM`P~VXtabYAx%t^wD>I~woVv1Z%ZdG4_vi=~-TjE>C9qb; zFUTW!ubuOTaqeS4V<1_9B)Xqq6aj22!K@0%i_y#Hr?~7rhLC}C>a@QzQZ7HM_o7A1 z7Bab*P$H|KX1%8_78XHswpxEoa{{&@bL-A4;vCdl3r7OH70A?@fVJV&;}AW3@0Zs! zercukN7KoW5c1Je)wtQi(5K}aiq3*`U7AAAX|_#lqS!APW2RJpxB-XUoj!n!!^cHy zmLJRsC{_spjMywcH*;?sLqjqj=b#R5_`|sc2^XesyvKXZKhyFY-byf}w%DxB@EfJ- zFzv87V4T(&no!D)w1I&nBF%Owr=)azv)IYcAsBgjMx70P9D5pKR>q9i^v$a5)^nm~ zMg8#oNOakekeod)Er zL^GmJX4@+l*Wg(g4U>04!}X4a1h?Tf*83{fYAP)?`+)n%;dZp+`inF}S$3h5xzfqL z)3t!00C$4o>(PQPAc+DJ@pFI05BwDme%HjX+(GNmtfOVi9d5U!Hz9IGuKl)5CCNC4)3S+AUmIHs#9y1|8K=PvIKyCLa-I75)x)b-V*St8#=ei0le~Uh zM3C$ot9ZRF!oE|Wghn0bphX5a0$|O31Gd_xzb}QB?%?*`4^Q=nA{*Dt&oP8>)}`vP z7(H|BapGxgk1~h^T2IY7x0h+{4A)p5$;;r}!P%?xrV@Hyx6}QBhT*Q}^&wn}njkfv zAqdY2Hoq|fI0rJ#tEP-?bX|2Ul53j@I>nQ0kT^7*PrJEr}`VZ)`iK5m757f z^3B}IaZ$ON%69)*Ezeg!ONO?tMMP_6YUr*S=S)Wu7y-bsPFT9ad?zHBiktL)XZ$5S zeDxUbY>0>pKurMLeBb_4oZnQ%#>X1-!@(Qoj@Xt8Pfvko0OzTaEp2 zh;b|SLph`@TEAQ1Hf~PR_R}uxI^1=8XnHuvZ@T9<;!HCTAmL0Vl$;Y4dTGR_z+D##WPySKH`xl1qY& zHTbHD3oLso`Z(GJocv}J13>2jC|P*#eUro8acI-88SlAE#sw}pydv#I&EhjXY{m=b z4jV-3`I1ek78|F4L{|gC1RGdela>agscvwRNyin-)muc3H~QtkGLsX;pvJBp^S z{mW9@F?Y-J$S7@~S}( zX%B9g>7-I`C-~yVptx;|HPMcf6+`cn$8NH}UZj0hmuyQ(=o~-qZqvF;vC3AHdXM8m zA++HF^tBX^GcQl)8-Kb=#At}I5 zT;}MwbFWs+{Fnp$$<2OWwXC=Xw{%q(v2vt0v$LF8CTDsp`sy6npE<%T23+v$@(tq- zIrWS@fSu{oTPfCd25h2F37QRS<;(S0^EP&mpB=(~PY8UUZZ}e#55j zzFJX?={Gfbx^DS7I^KzjA8?lfwhYhDm_J*T%V)-}nTW~9Nj*Fk_{W70djI>HkxX3l z*K?480}CPK5+w{L$LmS1?-pJ&6;g9OoIfYw{jdB!g1P;Pv{gnw8dy3Gvv2PPrL%Bp z?sfB!*{Wsq+26h2mV9U3^q&xDsCDi^j@JVo>%X~0Nx7RpNEQs3=J<{oeFR3~TN}K4 zSMpyGrL_|~{(W=uktT`W3-qV-a0=o&OVz^58|(6BOwHx=l8mUfdDkg;4FS&wU{&*? zR08z|7&&C;@JXdo4+W7Q3xQ?+Hdf65=1W=s=tuVJ(-{GUdDqbI%D(@kWK@x?{Wg4^ zb7{rGk0!T8>rgW8&-C_37rtYm4hySqg;-PXYSCwx5TMk_&msUNa5ggwcvb-9t^xun zAm*_1=U2g7_AF)-mWBG-b3KQ4$KRH|yzTGXJ?&z8ImNt5a0>yKGjeVuYth&~P-$ar zx4p61ShdYi0l~<5fxQi#x(PUSi5$nT;Pk04kCB1no#&BeT=eqh1->!51Yn2jAGXqu z+yN=l%}qFk^wkSO0N;<-(tH zX92r|8pa4-tX?nyN4QTzQca&-gX>qXjc?XRD38tk?(&8M0oN}D-=6}o9LGuv8cw^z zEzx?Q{lmk_%pFnGjwt7}FqHWa;L!uvT$|6Q0hhw(GfSMOwFmqAZo%gFjY|;rzUztW zxB9l-8dlC44p$sq@bSitzms?m#}dOk&xap3Ec>|4Z%Ont8yrZn`zWxv)PL|^Z0&9E zYG*uQrnU|ZUphM-d>)Ii1c6!f7!oSu_$Jnl(6Rh7VL-H|UR9qGr8$~;+$QXR7#wj> zYMo78R>pSsZE6{*H!(v)*@LaDad8aqBXtq?$%f7y%%nPf4Ha;ACq8fW@#0tKyLcXo z`Hgw^+vE9e!q%4i&W?DVM_8V%_299*lNS%_Y_?bKw?KHki}*^-EmuFUknr<5gF23Z zs100+&S?i=nD33Lh`2!*9`aR?R{-{1h%q`8Hn5W8J|U|$FJ*FzU2g|&P5s=A?pbKGlrEg%4Zvg6#026Z!06q%mVWkah#|j4b&Mn4CBEe<)r%PY*V_GM)!+gVt5f;tv zHXe4S6f04`PmDJ8!0c85SE8|l$JU37jT&K=!*^Q-aJY=QI5*6MnP?L-OI|FfN5q0x zYD`#Vzqw_P|C~1X5C4x;!;@I+**&i&bACc3o>>EGR>r6o>$`8~3^CZ{oN_v$BT7J* z-NBErm0PmZvjfj+mM6L8Zn(wVOfO0@r+Qn#JVLdhoMBpGixUMDs9@gw^fpakP`~p& zZ{bCohW}*WNL-4ms>VescBH}oPnjAsbEZ6_kgRoAK9jCin|g|O+tS#bMA1v8F;%jE<785#k~Lh}XbEbbc; z#GY*-pkpEl$?^OaifVHYcyHT_nZ6w{G8rja0Gvt>z4&( zpCsR2(*_h51n|(RruzX*D;t84{iW`>aM=ukaW#PAONHQ{*42!qBPQV`r*ZnTSY9?+ zE)U1Yz*_(TXIQ=Jh9ggkch3e^vRwgsjR`7m!;BG9>btO-Q({{<-P(q3Vwmpv$l<43 z$x-*%=#~Gz#1m>MrW`jW2)_Y$jduF_+T()rz!a?7EiMl;K3!s2c1;@Rny!QWbt+hu zhae3)%glm_u)oE0x8UTLoCrE_Z=hiB2$MjELwtwqYWqfQY1nV-^$pOw5CHLsAmOU# zz*h;j37Tq(?ISU|1tLbF)w#?pRkM6GRxUAU^)!&E#?EjlQO`<}i!u42k{u5$>4zjW ziot+fe_O!0GA+3rN#J>ecE48H`+g?w1Q*YIdjzx z180R~Bk&>*U}kQ}N}G#bnIzn85GC*YtPU=9dru)z+IFAU`chBKxXp(mT(N9)I14D! zVUT#Y&6_K+F@h+75k3NHKXX9@c>qfUIpj4U678V)gL~1(zsmP}D9m3(SaD-j^EM8& z$L?R);JX5ZVt^V+2?+|v$hIuL9^o}Pj+yj$9nZ4hMzgitC0!=0#NzSm0L-Tca{Rot zih%c^--o=Nt^#ABsxy=JZo)hp={7P_fku!*hX9pwWyar;=K#?P9FXFrt`NR1sP;8| zNCKo2zlLWp8pl$V}KfO$<9VvcPu6g@{8-M1x*xd^c#C8J~--sPZxTU*e zu*ll7v=Mi~k22rBVup(8W+olT4Wvrb)`G|^LcBP%=mv8MjughGnE*{`+;)Y(|0sFU z8$G(j7i4{XH+K1-UCPQY|46s4Ca@c>Mn(bYFWQ)UX}cvUaH;+pkmGSpgPSH? z1@piCOF+@-uVXuk+>d=~4ZallsuBqh@7vbN5a!9C@W$Ubc_Uzr?dxRaePB*{JHFqM z#|StNNZCM{cmHr~KX{w{m)%q<(X*E6$h@`W5MNWz3F}T2r2$=^oHjHn3c`LZKf=m; z0mRx5xegolHW=wi0r9muk=#inuyy1Tj8WX_{Z+I7C{OB~ z{BnhYrX}1=Gl9eMi5J3lappUPkL@oF2Tw=7tQ#>WrBkr83lz)A(u8QpB@K)7)g6Em zG)Ct4t>%CdTPfZCuyLa9d*nI+x{c_OpBK|qZ=LCV|L_<{BX?cy+(N@8-)!3)5rsBo zGeo1`^Q!5)M_OH`G?N#{2gw^uTRqUtvWeRc&7#SF;e)-qpTed%PazsF;%hI2<}SxKXdq6KZ{ zVC&D!A(G^;KFhqb4%GbaY5&n!dqL%Plcl9N)9Z*YczSZA(-!4MQOY zO633K0_>u}Ig&56U+;dg6ytq> zp9ER4WBQwxuc*lror$YY z`pg=4a2YyH2cDo~^ftGsk+i4PevzecxXKWtgV+M$ffXoM4==LAb4K7LcJPrAin78w zo7-cf)VJGv?|!!_Z`|`1h~3dhkA3j6y(tPCd!99=r5aqB3rv|TSOI$?BS7v*^ws^P zN7^KHDUl%*gsCH$R>8eZlUKbdNPyiAXRkeY-gCt}W8{Nqoeo;%V61dtRTMl4SQuu5 z1S3om2gepV8A8XjC=%JaQs60oCjxM^3S2`m+G}#0=%P*-+Ok3>vBMQAsD@ym4Il?l zE*7BAAclp+eig91^l1NOdieo@x{{chfDiK*L&gxHVDV4q85#Cql`br%8FaNP){QWS z`ex+YDl#b8#lHAlYKe>x5wi%{EbR^d!Wdh0?GrB@;p%H={tM>@>eO-f2tM?J_tQjlOrGmNBXxGwujLC zjjlq_U{VRXy1b!$=ewhd7-x8f0OpKa8ZTtq{akcxaYB0v`6Im+SoC~-$mt-5z2nQHR z@-lSmrnEs*jPQDO{x|QYd0O4Z=lMkGBhLM0S{CQFbR)Fv9a6N>eZmOSY@c6|;<hhiwWUcPxu*BJ^QR9x^F-_rq4wSw;9;9<9|Vh8p=Asmue+;pJ00#eP2 zocQ(oz|7U@-h|f2kv$<$hG#DRVgy+}1)TIwA#qL=sr_%5LvNWvhM(iX6O?uD)5Lao zdQb4}_FCSA9e0~gn<3fXjb}u*n!iP=nDC(L%b|-;Z~mAUOsDH`6(h)zY0d?`9GOw3 zY2Rwo{0fu3aD?tH;@N3*BEx;()WY*#EQMrq>hy}f&@=;TFrcLencwY0GVbRoN#xW9 zAZx{e(Ed891M;-O2d(aF?@tzemJ zc;WFZbmx+zQtCFaGU(27iQ}C|r&h8{9EvOAJroI*F_} z`~voO1iq{ki39+Ax^~y1n3>1H@0z{mXn-D0-bdGGr;t$m=PM^|%b%3;&jE=6{YlE0(UBHugSB~+Lu_RH8{?FEw}$i5 z^=lL=K5%gXHQa0MV~G zl0Aj^8(;$7iLNca{WO?4KOhgmW!x(jC(+Zje2i^4&^D^Lf)p*XzJy3!5!N(D50cpH-P0SMJGrjMXA zaOhh0R_{B(WcX5JN)Asoruwcg;Ay=pLkQIW6L9=mmrDymx z)|E-GIiQ5F`chL(AnZ%hi^!c-hQ!#OU|$<>t1sYxHNmSYs`xpl(Sa2hBy_be8OvGU z$zG~W06m{6hQnnEkwxorlmV(UQqsw(N(4NhCw&VY}cSzhUwO!@=VxSFfvHJ=Guq z1?bW-K+GiaSv&2md0-%Mv!)&l1#Z>UOHNJ&Ve%%!5XG=xXuQ6_bRMoe0!OIeWee6G z7{Ds+lJP2!+H5ExoklzfmRS2Dbl$Tm#M8hi18F=tY^?3csYFVw^_i9qtsRuLtl^5Q|bwAHV+kE|1{$?m~-QDOv!WGABbDdo+KW48u z+`Ak8OWEMn>I%TE|0vFnyUaQxAGjY4-&ej8j%O8WX5MW3TNFPCF{GQ@L!ChWmVj{u zs`Z|b(g@O7MRz^A-NqDC&=NgtG$2!*@Mbw^zIR}AMD?}l+9`ctwjP~iKxYF;)n2z# z1CR=UF26I@pS>l~rJP>rqoC#ccLi=MMpUC?lbNI)GcdCS@Za0**zNhfrC=2Gj4goU z1fbd19H#ZGSUzRjavKOivFti)Z<*JZByBO^Cu{fD>@29iVf6 zznq7I8o)?e-#aZ3+nxU;>CxpXMUXPT>+@F|CO{L4zYwsi0hj*;LQf1#ZEW>-;h^`( z#Fl9@fMcQOEJ8?DeR(2r{4yg#fLWYNGOcKY&TqrkiQmbfs*&F$IO00O#I03r&+AQd zK%SLh8(}VXqM2hjOlpfh%75U(2lxz;JPSmpsfva3ANPevttoV zK;)9-{f=3fKvpvH{}V2&3)3b zkn~nVbwfPcBUD9F)d9@C&RWQmKs%#43*?Da2G24!CLZ z)AEk02qOs34lAj%6%LeM6mh+yfN>op2$3!heB=&?6KhU?Qe8g>llquXDhxmu$Z1&9 zxIvJhMLCn!rH|^3U|Fz>s)*&fyohEul@QyX4MBQ24mJ!N@8p{uSmDx6Y5BlGDcN->5%tvJ${-0 z95Az^tsoZ*lnj_y>z+6K|EZ;nOrkCy;wjTXn5`HaAxjm<9@427g(0Yy={x+!`0^n% z09B1S+d*NmlR&UY+B)gW=xgUa3+lJQf}^p^ITK^ZJ&BmLNH1eH3A*?*rL5z^=9}iX z@@9Sm?^KmWP;abQeU-d<$lvhwZRrerpp<(@Gys+J6Wz3x=|}Vncn$Ayle-rn>gI*u zZD4Rhe@-d0cl21fAT{&BsmkX5fetLXDBtYD|J2y4_=1~{goYI=>g5}@0>_vE#6Hu4 zxbPksQxI!CWx)m6JA*Yt&bwTF@VOH%%^K`c{*&6bwi(N3V%Ec!nrt2dbW!*r~cUlUKA6 z-tG3Vmx%LzEIQP~uvXtbI$k%*{^r|EC#OtYm?!&hyDe&%GoL%D3A*Wjd$1f`*UdpY z`~$8{QfM}=1tQz);$U$R{2ClpS!D1|h4ycwtu(Sef<*(2$yWv%`?_c*IS}Mq?gYOL zQQDvH0I+m?I;ZI{Ssk(a5J~O`(#p`0B1bIdmTspQUQraq;HU%x2WN&W>JeF*BUZ&t z_(qL-<=F=SNk}GX#N55<>Y}a30f;C6lo7`>U|V5UjMM)o#wM)O$jvn1N|%ObAWd(e zZ3xhlZK>U7Qx|%>F0{_HztnT1ot1h)Vt<2~Vfh+jUm@J-&x|JFh>PP#IzKM@_E9G| z@onGN;Q-xiH3;&DV%@z*!0<2kKxx)CDcw3?Zu|rCYb3@R8cuAS;j+;99DsYoSc(JW zZ`5_5o2=hqc`}a7vs=wIxi0RZf@PTb9NY{-9w<1})klRWWOt=fP*T$W6u$B}b( zf8TyvGImqr@=1-Xd0YEalq%F^gkpN@!GqoMbBvY=0W+b#_{dhlgwHgH8!8*CPtqo&YCu+Jx9sTr?pYz$ zS*J9|1M7AWAE`&-(`S1t5m*PX3ZJ?(u?*x20I>YB7m^j=GOvw5myUp5Iv~{X_r$`GGhbZgPTEqaAS2W%yY) z8Tnyn&HTWB&Q@t~s5963(Z4;2^J359;AvWEv(@Co%a7eK_|0KA-L2@^mmx7&Tri9X^ow*R}3v(|;b7JhX(ytyIy+gISWJS5h> z@TcUm=5g=le1nc*#>oW1#4*qktuN!e$AcIb_AwhKk2^yE%*#E?KO8!6qo%bP3@RNX z;9d`+C>RGWL-TY0`Qg~4NZN{|cVb$gxvE4ne!{w=vo8^latXyw$?N$VyYx-9Kv&Lu zke1-Dw-MjBjED}o|XYUPLB8B_jCFSUkw3@3#f<^mr!X(OaqG5QuDCdMp^{fSY zZDcRVbKl>Pu;6Zx%F{c+SfOcL8SI*8On+QRwuo`p4No>Q*AXyr4QEVq5S&8FgnXrDtbk*HF-@E37t55NfIcoXKWHIM%_oky|L838z4<<6#z&7CG^53weZ@VU7=s`kz=$T$_$c#awwk0fZ_uI) zS-_Y2har?2S=LAtQMg%R*WbqK%rp3OCA64gK$4BEu zd12pMCM+y6W;N_*X0&!lHyL8$%FVH!7UJ8;@*8eG9jpj zsNy%dha%F#3dVTU(h#vfDKFNeX2s`>_*^x(`xVQez9t9Z_{8`YR3q^E( z_$Aq#7j$6!RXTJeabn{3_g<0}=4BtQ2OR2qarnKAwf(QpCiaB+PiHw6%woiBZ6)4>c`w5T|KEgnjyS= z`~$iO>u@_x8P7MyL*pp*_uaAkkJbYrg;Vo*F_h_gKS<=cbc`Agg-2}f1%K-!{fEYP zG)8m}?ex-u=chJ$Dadl$nOb#{%*6T@hsO?Q9C?17Z<{^`DVC^JNsYE7ChrGyADF#clCdr7&qxekH zHXpaTQl1J_kS4H^KhX~UY9}ErxYu~($LCXFj09AGsTAN?$;fWt>U@$5Oyl7oN&=Jd zQf%m)%LvHn26i}))n#J>1GXz&_-*Bm#vxI;;N*rL|8L&_%z7WC|VAvy@0g<92P7n|gH7IIS)IrgrMH?+5^;kQT1Pq885fzu(fT&cdjc6^k zwr9ekQPHBc7F&=1!L4X*Pqem&S`X))_Zy$`A!MGppX&CgjsUWnfvk6pwGQK-^*r1{WhJ8;mbwlwL?;6N@ic$UnizpFe2+Cd0k4o$|lxS2uSR7pOKE~Lj&%qQ;J8%5de@JvZQl)+2+ zcvuJv&Cx>_^I@)qNkIxV%AGxDm|)j|T~=`4 z+@s)g<%9}SLefCU88z_zE%o=2#}~BF<6S$x3BhG`--rNXZUN()zYI0y;Dbz8Nq%tH_kK0)K+MLiJ zuPxZYJWxUr+qK^`_A~7M#EQk_@))hRO+YprAYcI7n9(&Dv9&rd^o`c zHCX^~L8(qicfpqC5E$s|-%p4?RdAkdcW|Z5i~+#M6QlKToF=q4GYx8A)SLpYoncU( zM@jB%_rl|3KXv}{Rf(j3TgFk5fj&z>YQ0G9Iwhf@J*K@SVtVEpw|`_*t-cO2&5a%c ze_XgA{OeZC1Ls;p@~Cj+WO^D<(9mJZrVNahE06Ywv-Wq0zFx0#^eSXDknGMVD8z-7 zZIXL$n2+($d$$4P4})_{5l|{m@0!}^-*QpXg#^8%lY9}4PG+~MweKL1dYiTx+%I25 zZS!f*`^72?!Y?0OyssQTlJ6MLi3-ltg}q<`svYWKdE^LFf(xc8YjiGUe;Ya5}pQ2yMB_>2hw4b<0hlj<4VWpE{0C9XI9p_E?_5o9VR`d2NYG$LmKS z8(s$e;*x%8su6TqE&0}eORT!V&&Xz^hqO+3rbWQ~HhQvoFdw$56xuf7Zq)dt{NgjIt zx}3hzoO3gC}A))#0=M7fVKIIJm0G8f_||;QF@-> z0APO_&RA9V`C-%)_I@Y*fA7tnDEOH8D+aH%5`(~?hm@lPcX;^d>v5AV!$`dEU+m=j zvPELCnf}NesmuVj;@!fhH@9}xe)Te^Z!M6yF3vAZ!s^82`$y_2g898wGQfoTni$?X zGNEp*T8<+axc1 zJ=M7_aL=`aqPV|PgsUdmt8ZVPGMftI4>!`64MLx=u) zx_N@?oxAWjKCKg4r{I&q+@3W}O~EOj=0`V&e+A%jI?>dzyFTc!^8U+*G@lm>a_-tL z(0su`5&G62COS6XTvc@iCOPJ$ZC8*?J!;6zk|(meh3a^`DgyM`AxaiXK|9!Y(>a;XxnPC$$xg0m%VpTfMqcLxii7hd17 z1|Zz$KMY2nEQ?5R%G)x;3KxX==rl73(F54>Z5kdHp+9x>H684R#(QniA2IIb5JA=l zc#%7QSwi%qaBnmyy%7bJai3kd{%|4PTeCw*%bvYpM4T*2eZOJAI(fwI2&BE&F3&8r z{yn%~unrxv#zR5H@RkUs`!yCe0BYG)Gb=$wCVpI^X=-$8p9&+$2){K-T? zSn=M~|FW&8VvXB3uRpfn1r&`#=QA#hOv;S>(l+x2otW>q>x}TMYTSJHVye3+$qRzr zVDYuxyER=7Z@_|s8V|2=YMc0BVYg4&ug-|C}F>X+{8`g)IR$-laADUl!dKJ$#Z z{XNTQpCzY|Ex)N-bG?2Nno&8dY2}y0m@$$r5bqIvP~j-{+K;m5Z^+Rb#>dI7URFQ=z-ph2&!Ko0p5|N3POIGd;3oq1|qF+ll7v zFZ{PBhD%1>Tt_r4|+mzv5X@g*lF1Cq(ni#*uZDQbELw<95Edw6wUS&x4;P8gNt3q9HY9U*ceY{_qs?d3x5)Z^< zxfO%$yCd^=mA+5)^Tgb`?ScLUoNKW2m>o}TFzlYr^GmsG$(Tp3=`!R3&nlJV_-HI( zOEEmE!QK>2k5cjkY_U}72I~a*XFPbS?iGXw6#8yT)B+_9LunMf1Ci3CK=?ILBJ{6t z&lPHl2aEzdkylGXr5+lcS$&3vPytD!%1RQE1$Iz}#PPQcHvcij&O%XjUbt*GXVGvZ3g(yDM%+kmmK zZY4t|c6x84<@4@55CldGj0j%jV~m)TdI&G9cC#PLDO!V;O~B)e!i|OyH1qnZ`OD<) zS4KsG>CwMXzxY0zJitNiC6yhPQ2frDMGLVpR^{kiOg>&oc^zJ$Fm8#L26UY1$@AO5{of3=(|t)GAxnWfuN-SDSnV`VNSydtHz9{0J3jMd#84Cl6w;^Qo zN*;uz575&Br{DG&g~1i0Cm?CZ;{6^|kc);j4|gx`@n3Lw-L}x8M;?HcR&ukV6S|z) zlB6c5p{gW0^_h5fCX*doczfWu3?kh$Trn#7zjr+C@-!q?c9FpM;sTXPaWo;TczNX8CAo6IC)Z(`zSyH8ik(;fJ@O)~m!eV>T!~ z6d!~~)@Tsj>eR&p5E?y%$ws(|?Wy_DvTaSFB^0vwoxy2oz$$7Y#w5O;zrB60=Hds> z77ng{uu%@f!oVkMvey5A@(4Wluk8sE6Ua?-*>K)&7t^z|J@Rqt#5pA94;D%iX}`!f zWcjX0qHyHaX`7))Mv#lI&q-zc-4wcX!kw~o&acj}HhbI)M)91eFEX?c+e&vAQiLGg z#6emonf=Ky6Lnz?59;k1pX1BjQ5|Diy4uVGIjuW_W0ywzSS)LlumywPYxR7QXu4fXRn zG``w`!!E%)16arf11(CV1R4|h_eAFjkVwueKT8)jmC?O}0OJ`85yfB|pjyjo;N~Y2!yDBUu`N;BE47c-_n;0oE zlQ**fUU=7Thzr``q3xHy=P7Z6VTg9FoT0 zYV|*59N7xc30~tf?EvX>{SEZKQdQ7K2+52HI^cwLr=3R<`gu#1w@1B?*k;f`mfQ+s z^`Wkhe#|NsW_f%6Xma05P|};%Oa?nK_dkQC(~_(-58A2B*Z!9`L)HVNKyDOd03MM+ zBU*Tmz zEs&&~gM-_MZRsONBpXcIV%ll6%FRQ8+KqA@c|tP0=S1hMd;WFv8V+`mFxWzqYT|R12|4F)5~;7M!!-Z)rwg zOC%6g(W_+7X(*CSubQtcI+gYundDqBqkE3lHj|VP1cIjy{H-E$s=`=L{poQ~U>H>_wL?uGW_Z8gi3O=SHN^7L;n~Hh*O>>; z7yf-Pb&h-t!goL+lgOsosh=T2-Xwl*#*K?1`Cf2v^o*cL_p$8?o%J-}V3C7IMitsU z0gfBuVxa6N%a;J5iIhq`)AYH^$Q@AccF_h|5LJK6aQ&Qz-km|d`c-pSM~6mi$8&&P z)hjTwnAsEBAjUd8;R0(|uaPp%kDq+tXb!7=&YU9Qv<5o&OLbZQ%*RnWS2cH#S1Ef-0#H|ISgtL>1UECvVe zk+Cc!%eb<^b*OK>pt{R+#a)AWL&OcnKprVtVhn!7L{u_*sV&i&r)XHbiQF#nzp2Zx zsxv`sqA}LZDqOQ9fX15#Cvj|S6f383V$ty)7 z$aVeX$kM--uuppJ=KfloX5G}H8D1JSm*7njSq1+iTOhErT$@g!G|KJ)nVq?C41~4O zkQs<8dBhrcw-Umfi)9+4%z2|e$*OmekaG#-JqodPWV9K*Nb+?OMl~BmKpK*dN0W-f z9-PDXnrkL6*#VW)LK^(b=n?B!G=I-x2eCMx#0=v9-#u5-3z6oF_xwg56G3ZPc?C0J z%hXAOcU1n?D4rz3XB{fd*EXNJ&Qar03|j<`s)%b1*|wSpykt2eL9xuzM)<6)PX1zf%5a4bAJ-<-GpB z@k#=VOD_OzpCJ%1DA?^SQLw#YiKq&Ck|GHb)eE*N>^g@z6T>(Ta z5R`cWw@_83r1|y_0|Ss}ObwaX&=dqt8|+OV*hw=`bZo!z0sVJkzohbtsxSI}8hv9U zD(LVJxaSFXka#pCxh=~tv=``#v=xl7&8Y9P`IjJp8HJZKSyhYrN1o9u%W2EWn_7*E z+Me)Y?px*_xg7}2Oexumx6|jZlL4axKS-4k_AM#TzgyBliGj$1Z|MGsteL zJMca&8T9bTP)8*&at=`-J(syG2xOaA^Vynd{J%wJe&A0r{%uBGe(01uR&0R$Mu7Sp zy~#aGMD)Y2K#qs{PP>bjJN^-V2H`_y-O~ZC+!HkOsLW& zKz21Pz30(&fPE1)(l(EF#D2{+hv-vRDgA2FdbQ<7t=aR9#~6i=A>O}}ozWNGJL~Tc z1W9YJcHQ1Tn{3XuU#m~s{$FRfeWn8!7)K|Nc*%}@9&q>*iHvKbA0 zc!pTErC8&(tyEsc8^~7-s|m(L7g-~nyZ*@CwOis(vaE$%tZibH#dLrVq|GCxhlg%H zGKLKvk_}$*kB;1Fx(7_Uru)gUiK4bO{@Qcv!jvAl;NM?;%Hy(NqMY7%u{1nKEdWrM z#b45Y>F*5oy+5Wm<$>zrdw4SqJ83uv4jGw^H2|n$u_nhjb27 z((0i{`-9A5L?g=FH6Wmg)pjA5V+PMUm;FMt+1qy;Cz8{H1`_@Y1-zP$Shi`oq-Z@U z3+|3;pBiacw)MH?_~1P;NWuYgtdI^U+wOPm=V{W~%fEuNrknCIMu7(v>;!TzkXPRQ z!t&zmQXYlLChTP&#Prp76bQC}-Z{&IS!2wfw7daQ;PrE=0pWwvv`_M)YZ6MQT@So8 zngGr5D=o;ohc0gabno7}a5o*bxKHRLOLVe4x=&q_^Vv$a;P!!UX%>+JK|ffKL85Q+ zY-?`pPrN^%iTk!UTIyHu`gP8xa}FLuZEA2TX=vR=+SD~$z_Jp4)0kFTc>e0gIk{-s z&2L*~3aX7n2N_fe6Zg8ulSo=xcJLOvvt`HRUumI(?b4`f-u(Kpw0=Ra%s)jl`;9=L z!G^L~yn@MnI(sT_(D1T~jsvqT0CSMi;+@DAG644oaj0hb4Iyn*6q5mI)?q<=(M3|) zuvaXY4V#Sz7cIvIlYc!zCZ%89BL~oJ7JCycdp;|y&Eu~h&HkkzJKM{^dQZypzCUKt z>W=UDCppYjENx=8MQ{5 zVTYr7=S|O@>|cQ>?P^qHfvBIx``O^usqbga)s)TGR`lSBo**1T33BW8{H;s)D;1;+ zWCJ>Axto@_+2d+bq9euUTG!krB?GjWWyFCjLpY8eqqW8Hl78sU1Qu$Vepkdo=IIBN zzp9$S@93PD59l4u>}bX1v!>|kGdBcIpbS52_r}!6+a-6sr?O2mbd87)1n5b{be15= zsC&>jw-K?hxRcc8(6Z4t(B=R3F7bt_sbXPX-@8?hOUncBk_L1#*`LevHQCEW+O5ua zua}ccw*uwiRX^n{Y*0Ca%}1Y(-(8wklS6w#9o5>ET1kjxBjwbUf4Zj7|uMWt#M^mv)>T` z{)zf|cQ0QSqx1;cNrGeID#!!JKR?`WS#;M64XpgP$VI^$*ye>&Mgk2! zu^wKxo1f8TfgT{juvpSTKiNILd^c-WL7u7hv4zQH1C}9=K&Oo@9B@tTu}l2IN^bw< zg6u+7PFWoZ9}^x6Gv(1F1aDIn)|rN~uPRE7ypnTjf;p~;o)E+)5TXW(*Qby3ST zHE&H#6MrB7!!A5;Lxw9vpH~wksz{`qLdB8b=>?iV-v~>4k+i~Of@NgEGvd3tmOn0N zp7l{x-LqRU{Rd)`=X3g@IW}oN^{l}<^=RHl6NiY6j@1PH5Jq8XS0YQI5R~4~6)|MD zBA70k7-b1-O#&64*zSDX4MfplWrJfpzz(oE=2Di;dm|p3Q|>iDHo5=A3%%~$ky%AK zHEF=)eQ$w^SF$3%jS<_yc z8<1$`YDDL_9#j_1dT7!7cHYmG7&vEK$OTK8LmSOmcL2cv5Xz)pXq&JQL|7kaG9pq% z=k!BBRv1W4mXYlhR)|+$c{7pYRxLhwk#27aCOrPr^%Ht`EROs#uMxDlT~O2=DvBq+ z4#GQ-?mgw(Rm1=vh9=Ae{CjjWD-8EKr~_MXS=x>ZN5v5yIB4^LOYnC=H6cko=cFX* z(M}c-$**L$N~;V)M2wUa&uelgmmJ8HM6KMzwD)Ui7+~A@lM5@7cV!MHUYLjjr zgw9L!0Kz1xbc;i`m8EK>j4PnoyFouq#BL^4fTd7>v|SI6W%{Y26x_Shtyx1Np(*RgE7D9f=qKa{3R;e-d? zGqB<}-vdcxvGaIXa3aUn%Bq+iYd{vu*#fwKD0U^>k8-WPEN~Wf@}Ck>=dsrWQy2~TBhc@4s`S5n-%PxUHPeaK4 z87&}qY1csz66o^So{F3#1E_LJpn8(55Xiw!sDgq*N+QrPUCf&~E(m<$!pFDsuzE|) zox3d(p|~&d8>cYBd>z)1eNN8Fx_i5y6S52OwRhI9hqt7{87B2yU ztTUpf=f6XtE4g)*FWKPVl(mUvEvo7+By7~OEv-FbZ&>bPmo!XjGwI#{ zGXU$dNe1WWSs4tRHY5_>s+pnOLLK?N4ajmJc`DvNEyF|3q^yh{O-8BwwN%~vUi%2Z3*(ePoF)98W(w65;CSM@_w;9FM5f6{Ahr)Oj zYlprkvGXJ{K6(g!`yGe}mII0YCKg(~3;D4OjF47yNlzEK%(>E2GV*irz=vTxTUaf~Vrw@K}$+Jt_WfHWI<*Ds_i zSevhU&BAIY{T~j9in~4%S)A}-kP84VJI1o`fLc$&{Hs7vXo5xNODVwF5B2(fLefZ} z^3i*gN`5e)5QvTyet(sp@{xGI%5OE+N%~1$ z-@ye3B(mSeLt<$|t899C@&|up{(kT6wR|QuaK4WTYICg&Y$wkwT2tZarTI3eI%NKU zHK3^l&Sj^6xw{$Cw7cL0;90!5(~s9zTc^(1Eju?8a*%Wo?4&`9%(2o-P zE=~V0bi#;JvEj37l!|MMSO}o`8q2-K+_;c7FRXWGhxMnz{<C^|wNxl-`_iBU?;zvLCJTlrNq z$c$Yh2s;i{bb~;nWH@GF7%2!N#}J(A2;>}~WN_)cS64D4X=$`4%Yh0TTpuq*mRE75 zRo8`^T}?Zchla~|Fq8Z@pyCKA?2*h#QgALwCg1$VNh$7%$+G`&O_5B)XIy;1oK@2IBxBeF$*(Ccc{7n9M2&aK-s;wKUvcX zZ=m2}KI7zq2@;Lvatbb;Uc{niBE6bMCvr8}dSSCGHu>bror=1GxRr_g2~A>L2bAH8 zOMHOBp3o{n@`@Y3p>FAu)un7b?4U&)ckJBtnneC81L9~Nt49>1k&+91uG$5$Cn0t1 zLc%3|LPh1U|98a7e;0RVJ&lHc=08yKh7?9}g0nPymEhAIa;z2pjaAXC+)f>M&dPZX zU?VL$%~n6R#x5mMJ0;@Xnw73I>RrA2dLW(T(yD1-oiH`hmT@&XhgTgNsb3p8X$7IE zU{Q5GUJLo9l8(_Zwy1VUTmt0lJkZkP$ID1ayXFfewnff9l@2M@78BH`=?Xfz621vK5I64#I%Zb2pzs3sD< z*M`ogq*+9+6GuUGmX&=kCl5^|ut@w#6)V}y1!oW-l};dlsDKcE<3g1$@%w$^4tjlw zBW~|7h<70cBzA)pxPW*o!K_mtn+vg9-(LD{7c_@!>@FJ@)g!#rmMW)aid_Py4bD@7 z%I)0Y!clRlDt<`OIi%bPAi|nVQL;h;m>;jL%y~P#Z!l{J))1r4;8Iy^9-IfB&;@tI zMwn%#56Ibz9^YBDR(SHe7>g@qJS8YWgjE3BOiImt^Oi;}eR&DF^xnFwgsOfz{tWQ+ zMdceH&CC~;{C_;DaP;yKYI$7s^0l$9{@Z<71|n`zBvMVo%K@gYS57IzR=Idms=fin zlDVuEgyhF7*<-UO$I;R|v^n{K9BhMOk2ssmDW^)0JZTk+H^1s!>H8F2L|}!n!dH#N zQQmXug6#x)L^h>12RA=yoe28pKa=dEG!3rZksO>u33GPS;Nzu{ zgs46l-}x<+0fg20=srT|n<@NpbOZbnS3jA4kN7PT7CKGKC1Et&3CA}-V3U_4Jdf>& zRrrldRR*M$0KUWGTS>Wt6q-4-c&cbKhrhmw#_C0x9vx}aE-$%z|jv>7Nejn-lp{bar7Fi~(rU)bQ{tu%kcbCYij z!+k5hi0qkpKQC@Al@Zz7CAV~0UuF!%bv?Q)uHxYOC@}CaC79q(nwi{DYljQVqaa7` zSJD+5??0S8kqF5~g3Dd1&bUr-e$ub#eJv*uQ^`d%HjzUrT)VT#{N)7u`m(r_mdHyb zqpgIxfO1}-kZgj#F~WQ$3@=u5Yt@rz@Uj!LUOOYEgdN;`!WA)*fI1{~v8-&+)qewI zzjap7!H2b3-s3DPr2f}BAHxW7T236wKc%2bX_=6pEQlrF@*C0yG86X?`QIbN8){Qw zTC%F<@+@$Dcns2d4ZckTaHX;>t_zujfGxuQStHGNL6H&RuRY-}q2S;M`4%>-|6yaQ zEB0xV1hmYn0^(v0wp)Ze@d=b-$)r?Wi9UIH;z$x;U8-^*ze*^6BSc3`-Ong&!;qqy z4*5a?dj}@ew8oZkh0#_8^aP>kGP4Nl5--o9N3m>xRTIc64mrOgswxsW61Vo*;}1<% zmD0c6nvLmMaUPIPB(Xo=>Cuzu%O+$g0b77j+|g!W(MB4Z<3f{J*ylpc^an|7;xpH@ zDR1Wox0W@LsNRJT+~iP_Y$W}^)(9T~+(PmVbW&myYP$Sk4u4Q(kGVi zMEfYwpSfr@m7mSc*wh+Bx6Z4+iRegACWs%w);-A^L-Nawyy_oFKDdLl5^`OON(-s8 znXBJ`dycek>H)8%%MZK4=YH?E5BNFODMvV=olx2Vyp+N^>6!O?CThRk;>lm|WX~tpgG* zo)R`UIMD^lt*e&E4CYjM4^#sbbd|f2S@H1<9u&$7M#F$8pT_ej^JZ44 zp=75^LkFR-G9qtA-;DLyhnu>?UzdirdV=#GAXOyRi0=L169K^r|8&h;)GD-4ibbvc zRX`mb0n3e&YOue6K6P4As&zF78{HQ;NiE3fq2wJzz-D4ebY$%L8LOg+NT09p`NNpt zomsW<63QjXa`hC%o%tqd^M7-fl3+Qnc>j2O(-bmLQ0o5GDTt(mzvT)9UH8nKaDz2= zD=W$%Fg#IMNsFFO<>~9f2RF_v<)2o7*@60pv^7f!;VKqN{0^z*Gl(v^ogH-Egn}tP zorSm}VHY9XM=q--@ml(16=872k1n}Yvn_}wKBPrFg`tvtJDz|wKc6yi==rY%TU*zv zXX>MGh({?&XEv;7$6UNU=HO9ym=-?2A}r&OZd$T%9lkYL@`zBKZ^j#Gyc@{1J@S2o zXw3={U-Fpw|9vOY*6;m>7b7}qcFmoO89@s^5?m^aJ8PYnaY;BK5=2jHsa``U9=Ier z`2HH#*Oe4XJ+GoaJ31$hCf@gkMRl#Z+`vVN9@rxM#&*@dOtmI-x#iGuKd|aO&;_o% zo4Z;&79RbcP6ThwA!0U&-+3@?ok*WpPRnl-OM(eOJhdh16DXKKmXiV<-&esR6nS9@_EPrcCyJdE@_{uemU^@pV6O%yC<W zs#zJ+MJd8)K?5P*M~ZkR@H6h8d)Q>*J$a@ve7{RxFvC-~b6goEpEX!hzn!`n2Pw%l z8P_tl?(I?hd#Q>N^hc&OpO)zIN5_^DSTif&{{VaBZu0-{hZ{0X(8=JlSd4dVECbGxafMl1=;q4HmpI=0J- zGVU2QS8K71vE}`4-w<6diI`Mr_E33RU2LH)rF@04{zS^6V(-gKcDpd%u*LPl*eeMt zkq8@*9u$^_6On4ZH9R`yH5BIgC*Im@66&jAV z_)x_pzAyhh* zYXqT6($@`vKp%r?n;dh*Q*N`cq!6sXEKYWtAS`x>w4urE&DpXf9I^#W>ByK9v}xdl z4O{YDC5mdvfb#^(G@mh!skZhAvjr_v5oH6-C<08hO&rwWVz8h<&1kngoU$;+=mx`7 zzqbib-=QFf+qOm9U|6!2jvc;_C{3hCvL!8Bs0?Xvdp8 z|BAMY7WtRu;`cWbj3}nW_o)zbnv&#N*G0AyO7vx;3Lly8r_%5c+b{8`enWl}*(c|# zET;v2I8b_7U)}Ivx_WCvsTH1Am;eO}ayVvuq-Kf-rD#fkFQna4|2vuA$)c|r=Hy51 zb{BPvFBhF)gh7VL!;07(O9|Har^}{jngCCaTDhgg*iHogyh;#KUX&M*x6$)qz^2hd zHlp9jh!P_v2h;5PH^wp8|L=R9@h!5{Ws_1g@`uP1y4?MR;igMsU}@er{ISjT?~H$q zlW)EQBN7uW-pLm~Hj+CnX=n`Q-GTk~9jxgf@TBT`lr{;*r9kMHMGVgmhZ8g&6g$g^ zH{7WSq{zDW25yAHE!!s9Bh@V?O=uPeCJeiz3-vwvycSV-4hvoW$trEe82=7(i4vbH z83SNxwF`;xTJdtKcB;v!E~a)#j~`{kUQ^a2^9-xOiHfRS>jRudU56X8RTJW|!=A|B zC`LpQ;_~LHszoI8b2ax?eGiA{DA5caCa{$qMN}7jHHps|`6g~^D0#*nod(ORIh#5Y z+Llp%Wv9wTD~q`^%O6L)+XpW^cYe&Lh-KU&&0_k8i{laJUH+ZU8xxHC ztJCOygf0yraq^o#4}7*N#kB#|gA&~LSw)fy@TEYAmSpVy+p%U4fZHRIMq$tw zUu%*i(ttQsBCeq6{v@*WYC*Y-)t=4NM?z@+AVzIqLcwL>STBt7=vV10;XB!bMtt80 zWW*!jaSmqA7!#I4N7&=;mf%^+nz4^u(ki7Dok&6<6|`(DVbt>S0?{-V4RbL{T{fC< zn_Cz5_3Y881cN=E=b9z`viJx2G@BfphJ%J_>EJd*)I_VNIj%jv*GY@tF!;u5SdCL4 zD0O+@R~DPdVz3IkB0#CBC$bT2yT#857F-}>i{Y`RA=xIkl`9}?7mUU>s=aJ@ro{o= zr6C_Wb3&>H&~fEFDb>5SGS$VX`0Ws$KM-_pz~P!~Qu8@EZNE~K(^0BAe1Mg=k7I;J z5{3w#*6N!dG{7g#+nF$YsOI*srRZ=SiwurqPPt~+Uf9i7ykEQ{GEp+8HyhoO%Y07f z2*-AjEtKISk&~LJ{^si2Z*rK>L{7ZLYZRW|?U40vgcI6^BNP45(5<+fqL6oX#` z*diaJ+{$%}a7du?%|R*lNkAC~-@lcu9b~b1-ExW23s3_vfBf6^!eH;GP?t(HCR73Q zfK^nPgxo2j`B3XaG;aVJ@7^!kr0+rU4Uql=N>HR^?4v~er07yn-#Iw&6}>IK5!tS58-2WYPci?|=_UHfNNfwD*wqWD3X%Z{#u_NLGRb$QeAblMowC3-ip{Lq|t z1{`vdjvJ@sH$xo$4mmZ!P|axc25VANgfdm%ElKyC663iy<$T#+`@BfuMVDla!|LB^ zLG|@HIjc{x+UVU-^f^M5=IF-ClyI=u1&{SEl6CqZ9nA{-&17;Q&%EAp8$jOU{+P!x zEeC=`IVTK%(spS+s~UEZ5qJK!eYdp*TR``GvDAX5-iE@8X*T2zhbVNY5N-0k%da>Y z{GL^{PTv!*TQ8j5F@%?J*;oM=S1_H;n8YWV@4S$|b8dxdcNbE=*~bj4iO5}WG6zOiK;4ao z1dEy{5hfYbDcqI`2KC~3SS1A|8K5LS$LEBr2=P0&MMua2A5vfvBsvAT6p{&$7*?xs=8cw^&~GWxBmyCP@PejvjS^YaT~ZE1 zBJE5i3-OS*FaxYnLUl@U8TZ~M7Ad(EQ3SBTF;{2Yr`#T7bvk(uD;GGELXxgb|isPKH*xQu%^^_Wn1ibH|^z4#V>C8 zJ8s!8K=!LosAiF;PT3}>>>exs%_&p3;{xMK1D-&~MvJjhdeK-VHjZbL9*5o@Ah0U` zfBpPHvk>M3k#87itPfHESPQWAQ=tk7S;(OXhArFYf2k;(nZLCA#m1;JaAwMfe!LZh z7qLkY@D%W4rT_5ZBzi^41n3ETB<+FU(0SOXQRmJ@#&F}KZbFGl!5A{qp94F)@{VvY zui(Y|kc|ecQYl-evgg*5ItGk|csTaeNUZz`@$+mCIt;P0>~ z-^jC&I~-2NrVxVgt2O&?7o2&7GA1O9Lc*NF=Pd5$gKk-5kQ0qi*6J*4cL5RQ4kMJ} zA4o_>$^NB;Mi&v{79e`h7g{SBs*oC0STq|gxy3JaA>I$Tl~H-e}OQbM(%Vc|Aqy;dXhMDtK$0d zv9DIOL!`x`0*$Y?H@@ESGWK0mxof=$;++itDEj0^cFNw@+X26?KAWQXz3zZ5Rm8`} z-AOThSdGb`ar1UT5a&uxaV6?|fQp{-adu=Hi23-|hw7>gQ7@pGEA$F-Oqd&5IT#UQ z5h&StN3x+j6(kb8HO39+^rU|9c#MiNMtTg7TRQ>`>ai4h;{Q=}rcq6teH))GlaM_a zAcSoa0*2KHK@q4>f&pRw+|K?|HaiUGZ`b=TE+8?q_1(m5X z-ct#72T`^OUopJQwiV72;rt$x(z!dst45ppoz4d0>o z_kw`Q$@2^OLKRr*7tCqlj|TaV@VzmKU_>~m2o}Z~4x0?{>xB^+c+J7z(<3d*|NJ~$ zXixlfmUg+5;kqMQ8(>UwMTPHUsGW>u!BeVWMas$83z_@Zo`6(zT;Ejpzge)@#Aq_X zVZ!<2UAnw7FdJ9vEb%F=jZ#zjJ2voP#=O$;??CLZ8CXEax9O}mxWUA%H6k-D@JtoE z2w-jnippsKyql2&ATt7Rp$m$}Aaqfj5Z^9xK|ef#<^V{93r+x_cvEdwMARw*yf6K9 zc0ZKVf^6BCc-|ZP41`wUFd1aM?q<{r`HL)hG)7)zVEsn}MiA>RxoGestJ=xCA>=Pp zK`-2lP+>?1#=ZxziwsQB8s@L|GhHpKYK`zF1G>S9d z7RI?}lXev@i8Q7CySedO>n`onS=IZ{wQjW1m03fu)&vRWWr0F{xW|#?f zI+!0_-slK&cL!Kieh~D-2WJ7rD&V#u1orbv0;oQK{(wog_*tK0wTfb`HX*YJ zWQ9LF9)O}FL!LMVegjnO;y(g-pHtb7Fy3`HPqUxb_yQwu@BQW}xR?W~4G5HsoZ3fA z$+0mD;zpbIU*;SPmBWKn+`DRck+AfQk>Tx!@9nAt ziVvr{9_;|L_dpXc*b>klxus9TcZQqgiSkgfCUoh;v^WY%GMq`h8RE`JnHNAd2STLb#^|$?(I!^_*2Apk$gSObDw~@VzqZb)iGD zs4RRlSmWmTG5W*y3wbvJA+KW)Ic`uM*FkG(taR|s_d?zf#=9rvm*dcm=PbXQby({@P#&s z!=Hjs;kw5QS(Q%o@osdD3tcTw`#u9ig~ivp=C@z!dbzog1*@P{urAN=tTuy28pxyH zC_OM4^aGsW$Jg{KOS@%H;pC6bOa@)W&;AT&0lhyv;3^9%kwQ0HAUdNo-+K&osaU4> z*~GyJ6Hs=;aBJ)p!Q>h7Dkr<%EO&$N2f5$Ww?w<*92ioqVoq~Uh&IZXQDCy4=ADC4 zB^$T^T~6^kIIC-d+CQa|+JuPEl$fQ2fhEJCkU#myng!s|vG*%n3+3)vw=f2~f**}%wO#C+Yvm}J?ro8auh8BH$WUMN7%qSgZ}*uY5i z_jm|SOK?q7fVoA8z7k}t3^JQ3czS?UN+e>$Hh%EV4HN%yV4xo3mAk;D7G9;3_p6(y zeaz42{C`~Uuj=J|{8X&pet32{O9)P&YA*D?{tF1%R%tG*+^^g^{CX6+$^|Z*;CcO@ zusHuLMM<0@>hl!WdsT@o>L4Tw026h?irNpWTZ?)mI>MIg!aM3G~vl$#&7Uot9qh7_t zRfxuoxB?+_DEQpJJEFIx{-=M!m9(P?S!sjfkG`mcWBiPbK=IvWEE-rlaFDsijjsBr z%hIm;W+3(V@W1||1F}wNs?dOC!WMt(r+@#vY&;s6W`_|cv;o{wTJ+z))xbNG0gmCq zwXxz?Up95J&jXf5vl~SSU=q>ubrvg)LW%-VSS--31H=CKINX330LF!z8K0D7$At2m z0{(u9Nq#TW}ZWw+jJ4WX0)oSGxI*tvn>IcakgEr?<{g>5=A)g^!MH)7LTOU&(M)+R~Hf)Z2NYmBClumfAF zML*X39ubpj{L3_ShLNpKeQ5@Np4BAXT#&H$>BCtEbAGwN0*B67?I*EV)uYc7thKSS z#Qv_K+8mW5z==z3^G=xmQrUi#Q`knED^r8Sr=08nQ5zZBT(B=v4LpU8qzBt;!=o#i z`-)E&&ywoLgU*6fbGavxw^@5aE8To6q;6(qNAc=|>k9Cdk(*W4`1cEFpTheErTy{3 z|Bkv>X2`9q_|}=XTrgkv#A9ntD=<+gdacMrW=0xuL=cW+1n^)R1?0x4Oauuz-0@Ov zVSbbDiBi#|WzCJ}MHi}E1g%^Jtk7^OPP;ihtpJ4fX96JK+zyHPN~$LMHID|Eo3)G} z7QN~r3Go8csqYr8`lAGx_SUJ5--BXvU=bv-eD~_KcYat10DvK!NbPn`0NEBZ(ZpIn zfw!X#dWsNBEtG>9Yc$yyp$3YOraQ1&seCX12{vADr8_uw$l2q9C%6^j%OsecfsMFq zs{3ELV6-4e2NrTX@y!LP{cN;? zV3(my>thExnguri991OWfrW3ix_wr*eXOH4H+n}qycrsTc?G^Ms&=RDOh<#t=aM*07OWVC1$Y~l4iDD{NE$60VLdCfv{|W~_gueor&qf^ z3sD zAAE|h_{qvB*5``^!1d8s0*o&4hF&F@dBe3wVo!bG&%-iesb344nIa0Y%af*kxK8oo_Ll;ThrMWG6z_3G9D zYz_j5M8;-VaJYuP$NAbDLHa`UAqV6c1HNPDT%4q8P`A`zW7Fxl_y7?`-#K-5E40X= z;i^q!1ntBAeF7am15iP^2F(?MQT%45+C{P#yP&WsZg4zV+!0gmrfq_jx-El{E5l`@ z0rhAuMT!m^eBrl5EZk0t?r#O*91}xYDx=2_S7W}O98y>UMSRfsI%~6ZKGBwbt6(P{f(JUH*L5vNAp+fGw#BTR{b0(NtVwN5kUI1_ zS?7tyFF#SGhdL)R9^H!gLh;i94z10}H<9|e0WdVQIX4S{MOQG`ty#_pZ?5IjcIb>T zfF6;Mf@(Zj5Cn=A7W70AGEP>tCPJ<2iS)amyfz`K4?<7wxZth*Hcp!hVt3&zalemU zLgPeE5=_N}9oJ8;JjovK#NT6lc5PF(pN#AxIyWT zwjesk_#krnwc4BypZ3Y-tQ1e)hdy+*>13JHq82o+P=4)1q>Wxds@{qP10mU+L!31u z)`>y9PO9+EUfU0a5<(lslCj?Ew~BA~5y$|*;kvbPQ-vT)m^rUHgUW&ZMAbYcC*~p? zOV4bwx@EjFeg52=st>&nF9NZY_(;!=!jePJiXj@I7j1$kO{SokN`gl>7Fvzt#}d=8Z$q14;@7!zUJ&l`EAu$2`%-!f_S?A5~kuHJUKlP-ZOP%utn;9s#aw;CC>7-OC zaR^6@ev?2?!HR%(`s4ud|JDgONU`n`tSt9R-fEK(#&B@w^#C)Z`x>gYll;vXQr!Qa zcT?PrRmwot-3>VNc(7)5-YVv8Pd;<>4=@s65%&jGJ^h)tZ&T+=UOeW_7#0EAq6flz zN@Q=2o6f*esPqxeTR>SizLW_otNS>|>0!#yoiEiYoh1C}sr9ZfN$QgQiZjJKs=E7F zF85MQH&mDVa)@7O?2WtTuGMx|M|+%b>UOY!>CPSpmYJHT-vFD;RV$-9i1X{?nrrEY zXJ;TSKPVM~TbnUR5qw7;?_7sq7!|H~UF{WwXV_y}v47JR4F%}rsCOClo2{8WO@TK|xGrx?Ax;+bGd8p_%0?K=h?){5gLc#MvWW$+Ddt>M z&!2syRzrGsz!+Bxp+uEwkwQKT->8Qrc?M zM%V%eA#y`}OA^n}E)iMnuX{^@BZky2$&7hrtUqYia+=*pt?Lt1&eEeP?V`?>=K7v+{)s`AUbKLvWEq1cio(tH0 z)|$86%xx}Y1HmYDFttG<%n!f;VyDfnC+udUU8h3i*4ma5IFG2!W7)cgNbW!b4{oZB;H;y*3aWz7#G-4nD^ z5For8cmCQ1{O|3ru5Q(pj4)tl89t zVn^{!9xCB5#X3;C+EAx4I&>_1M!#(mVZ*WiFD*;Zzn^19y5=A>eQt^yQK;xI7K#~z zQXJOI$-536&U~eC0$b8qwQpC>Y=zD`;r1Egn8r4#1${4Aw>?m|El}5xTbIhZ2;d3n zePVzb7cYEta&`z%!m2f8*Sso3=zK*u21Mh)K8?K6AGtJ}k5UB&G-dG#!A=C~l>PEY>4_sKmd^h_*s0*t0SXUWq zl}5S*kDvqp>8@{;ZYV`U?sPY9X4y5a!;u*D3TB-*-o{7>#G+j!!(`hE)DkD{bQ{~A zz~reO#yX3AEen}6W>XWkX-=frR9oOc-Uy;l}L+HW)T!;k-Vg!BWUZ2)X)`1Z|-EGGbGf0Q`)QlByJ?|x##FC2-! z>Etb86k$)J3I!D5_%2RB>q7e2Sj*Aj4cOztrbjMP5cn#RNFT(fUvYyH7x-?&%qwl5 zzLVM}2(lgr(!=!Rj-E-6T?D)BKRh9sLszGjb(xa_0S8vV; z!HjfxIMD>oDt{W^5=o4YL$qYs4ZaC7~BA6s)-ZA3r64) z44yiIPHnSSE17SOFkbwM(8~NgPtBc8$Sj9_?f^nJ_0Yd5r@)-N@El!lUf%c~@qjH$ z1s^xt;=4|;JPEJ_wgw>>Cnr{0VH}(_duUlt9E~JhYd+7KOeW%1+Cmnd4W1n!X4wH| zHu#AhAfz=}E_y@yKoEplgkZAkyVp)x_iN$x0d#HlVz%X8WSdPp043?+IFC(I9W8R& zHw9`7RQB}#x=mxY?QXO{bs%BEmnZaf7hbc48qu{bYnIch#ce4Dq{2~~s&7rD^D4#* z^wQZJwBfi-@32M+8AqdP9+NP=AEPS!;gvfM-$eD*wg3FFf697flhKZQAf4VuFX1c# z(VGynhw*qLGS&Ru3Ad>kb zJl~HncEB^b;OYG!O|FgXf`erh64udzHds1l73=L5KlEi7s2Z_m_{oem+hmn3rr(;G zZG8?|xqkX34T@fcV~u3=uAQ&^@amnXWZ0XJxJFWL7|b%+u(2~=@k*De~E-oW^wRB;y)4v9-{ zXw$k|7d*QU29gmave}HTX+u|0kiHs8S246o20bBUxbDB}uN#VevUUK@3s_T>p_?pq zs}0C1CBmIzo8E?`bs-b7Z;|wESGuvbv z8MhUfk#LW7Ryz6SeXHQ}L{-q5X0TUDzEc{RTl9yu1~|)8GtD`?4&yDGqyq63W9MPq8&m`EdlxmeplX^PqXOCXB(8aJzmW zEz68>UDohd6U-w}NjO7qKvpb5^x4R@2cVj=?N70#wZRe$6k_DnW@lnk!?a4EDmLQV zcYxk?G0~S#2)cl=%iJI6`{h8P@Ry;Gamn6;o3L05b~`@uVsGp>=ZNjwn<9hIFB!wX zP+Vz0-<7s6($8D^f%%^ez>d=dt3U2H-=rhwY>Mox`4Yy4zmSPtWO2ZzP{OeTQn?b! zmW-Y>f)YO%Q%&x_&z0zH`hH?{Dj8To5)mg-S(_#U*7;sA&Tn1PN5*(;?H?X`{#I8y z9l>k^gl3_zT_CIkzuyE;wUF`m(3qqQtgLyJ>YC!r{ndC!&(2!C-sU1;G3IK&0Y)g{ zXY}meU{$zni8A@6FKTl!Buxp%TV_s@fJu_@oRvV{r3*>a;g2qSb}5#HFTJmmfAT*P zPBpX~8)3d@V2-~lY}uq-$5S0ClqH;6oG|~#RvM)mr3ziu_W$bqDnsp;AE0ZEaOwc< zDMC*ywtE+&9}l3jSJ>8%FxTG7-a^$Y&WNTAKvN9F=+6A?!QBWBY#m{4>94!?E&8sf z78U->8mZgf#eADpXYtg|mLQXrM3fsex16$itPD3$>8kze0=%v*x5{CkuZI^I8H=b| zqpD_x8(lkw=uE#EO>l&RNOgbs)6wtJMy!HwtZ6tL7trvX&`f8imc@88yY?+*?edWt z6Hrq?*i$UFyKf(V_avw@+hbWZZ+(KCUix1{;2#QQNfOh|No5Jk_RU(uH7q%d7LD(V z%;?aGo~|$MOljyr-qgwFL~_(q^e0XXUHyFZ?nCc>hRi=Q6@TNt<#E_0evH1!wd~73 zW;vR2Tl)ulS}NuWQfdZH+^b-;?BnKUNF>4lL$2F_&9;d{qzb*?Q{$V!-gm}Oepk0< zqXt6Rk$FYcrU)Ow+8lT5yUTxpbl`t;so$xwF}doTD_ublF85krYyPk-C`P=$z;bzH*VMTYMFd0*EnDrf*ApR%ThP) z-B5$idsHobcyZo`3QJ2#p!rAxKX%#IBP@89zA-X#$~!2eaBySCGqD%JXTG||u)ICP z{ou+v=2GVngPj1h(=pfT4T!KQNK*_a_tAm*W#vzANu@xww@ga+YVoEinpvWm`rd2m zd-~6=PPyswH09h>9ln+^`FwU$&Q;xK+)PuK2$gRs4_Jk()a_n`Zt~XgUROLVkJ)CI zU#<7MHM1v_0`2ygr8>`$4cwu9GIHJ5r-@_x6hQCp1XVbl|}GX z`EPVqFRNqt`b=h4m&qHqrF+O9p3!GGACrY>naWe8uU?$|^{oX((=N1T&bx6TI{r%P z0apH%$xVZw@v{S^E1R;u7=AX-IRBI1`{P^mJ+)DdZm(Nhx9!ys)qgi;&d#NcpG{G9 z2gALECilAXF9mh!zuut3=~u7{=3vE46r+6Czgx^4==}YPqn;q6v|f2Fa(`*}%CN)fK>s zN+iEgh`igSWh7JBf4P(WrkXJIS(mlq!8d__nUhKaOCydf3E=Fo0cQ<2F5BS!C(}s3 zuO}Hd)G!aNBv~6>Fn0`qcEm^Xk%?2wf=KkJiOEVV6K^(JH?;(8lD;917JzmIRCdV- zpqbteR8$0O;!a{5Ck^-6?y|yPx%e8PkGV~CB9y{esvsn3#gP|u*8TflVhx>(|4m^X zFTA?+{W|Vjl8LD@F=^6%UjF?jA9fvXSU&ndy;*m7bum)&XFR+qRqd>N+7ZMz zAlb+RI*;Rvxn)4JQ&)(Z`U|EE22q7$h&#tZMigKWc(QU?*%=5v4$PjMfQhmJ0u1*P zfSLdfln{x6S3lsC1l}tXo>S>$^>5i%N{$YD%YIS=01{ZRFPS)kssNE>1#f~7XMX6S zV~pNNvx%uzw(#HUyA<~iZCL1X@mJ;LixN%wIkzYCla)9f(zfNc%9;#By>!%A@P+)y zyxB<&ORm0#e{u5M$CabK(#ip5?y+Lt<~EV^ma`ys)BsoTD(mj(G1K~ zye6wR_*aDg$1^j%UgpSU!GaD}ZK}4JW#QUlw%|IcxNN7vkstkos&4A~GJe%yhx>AM z?VUqplUes5imo$HP?R-kPugPMBXpQgv;(Jxq%JC7;q$Yk>W|Fz)D7Kbj|^RS3$9j= zpN|PW!J4V?F*E}{!85~!0(*^U4bdmZPBg9Wx3gaZtPvgpX|Y!>3(*s8cWL%chKVzu zCeiRMAE9{tJXBfq)KANl8*rZra6I^|Ll@)f%w8Ebe80SN#C>3$F@a zyg*OI|J^LuGSVx2;jjDjX8WY|<+i9n!k$Z=5^S;GlwQX3rF|i(YjrEwkE<4S0#qs@ z#6Yv>T*2={^HcQsk>&wrR;Y@Rs?Lw>Fr8N}Fc9Ri29U9efW!UA8ATM3=@Aj5Wlf8@L2-x4~HUkV8vBLYwhiva+~- zeu+=)a@ymk7dK=LJ`mG3@3fOU9j{&Ki23rB!t`gBrbVTPqs~z0RBw`;B^&dj-na=R z0F#^EII9U=R}IIn7uBj1R)00o7cFeA%SE)}+t~tM(AktX)f(wE)rE1rJGYmtfTT`h z;V}r@Y?pIxn{}}O7&1%f67BfXI`0O?&K}mp3=Vj!ZGRE8i(GZ>Jp$=&`0^g_4T#=% z7*E{81r#y=ZRw%E6zbS8TEc(!&;2*XlW7fgfP%uXrZGvTjq{z!w|@ut=dHiU7mG~b z?jZ_Xr$H2T{|*;03EX)Rr`f+iqy#$k-!Q#_SL9MENOm8^be$Ji2zECv*iA-oy@In8 zn>K|Lv@VGYbgXR}ZNx25Qsr%}Dd4^Y1ay0m9j}&v0;3o40b|2t76F6;8g6%8F|t^p z@6E92pKo-aW{ujlWM8`@CwMa~gAu;@`oXQu>b(4jRlN~M;rIbUPH3?%GNu)bx(`N= zX?DMeRUD{Mm)o+6Agq;$#xw#N_NUVzLR>qS^ex->S?S#)8)!I|CNZ@HxwQk{7z#me2TmmUZi=HW|OD(&hTIA`J(sqb+$CMA!iyV zH-zdbho{fX?|yDubE0rYTW?riiS*^o7_sJz5~s?6f8}&os^q2k^veP=hb~eD!XjCgb!&YkdeE98VF1E zg{iHw*170QZF(t8-?t+RL(Lz8nWMHtH~}Fjg3(#+ij7Sln2g_7lOG^@0?nhOrTH@D z_rLl-+;k3S_W^I)y&RQsP4)aLbM4{QqljhOgWP#rms?bH-<)=qz+3{J7p3meDvPVmU>yBUao|jr`VeRT&igFhE+oo z2r%YfQif|o$H<>O-(^3wp8~B3&#B00QosenoYaXsIP0t!#(;D=AWx@Iv35c^#OBf_ z5@6yh?2VCvs%Picc50-98l@eIS81Eq?LFw~$VtOHqlumYU>~sf{RR2sCUo(>BI*o= zvvj>#hNZa@YYuKxV_KOQk}9?Nw2`VyEG_qjS81Av(*7iCX0%n7Os} zNm;?(V0Gp(yKP)pBnb!=R*40a4fMD)=@H}FOz7`JDWcP@0(s^ z`+T7`AvIsA(<%bgQEP=9U^A<6=l$qh2`(-LB$!rUBBa^A6t!k@9ud`+FWuN2#==)g z+TQzDu(t|Nln#XHNU;-v@r5afRx0i_1RDa8LPy%{lTSCB|S_X{>cCKf37L1jX<3^yy^E$&sTaF4?z$~PkjG}?XPBL7* zo9QIN=)m3CR<1-BtRP=Z#_UZfa)p>^W`3999slHgp{q{UMMYFgqQ#`(4}vZ43L7ui zg!g#`+rap$D=4*7s88Ef{i$soDCpA&%R$x{E;tTK27O3B!0o0uW{Oiv@LC6AoL)vZ zF5aYaRRlHauy5$Wmsheom zM4Y`6J*<411Z56;cV_q$XUPzUH&o(fY{A7sLfj1^bG+xfcYw3O{l~$0EqP!HrJ&z( z8z=m@^9CcEYL2Xk?;;~q0$w-qW+mQ!Xg?^VfI(@j(t7;488X28`ZmO$?afg%%$jYD z!e}Kd9>@b|=5AabPX9|$E^m|$q~U<{5987*4ItJ!lWmy>Fp5Eid3N=5%BL0-Lb{Bs zQT7bg<9Pi7br7C02G8i6ou#z4xpweqi1t|$(Wm^;iTj!fW(u%x9+_K3hwC)qcIf-m zd_|y2Avq^gXl3Sgi%=kMtT%ToUzTdkyRVcgds%>1>e?|$0)#w=r@Hf}C`wsX0JD3C zFJ~aEli-a4yeitojqeTJ{nxNJ%&tlAf0!MxPIAEMgT6`6tI|t+S#~(DtvADypX->Z zC|#a=vUkeBve-_#S%#KD$e7ZTj2EP`9#U?zO>^rCUw(}04n-M!nQAN521fgF>^YwB zT*Iw4hT6xyQo6y?c_*x7?Nv$NwDyPcCf`)KNoxM;xzLAvcq29)+ zQ65c*f$Y;4iCck??!Rb~>4_x#%50J@n~Xp-n1XIA@P=N&#iNbPVtgIVh*nuy1{|?~ z31_`q{{u#xCBvy!K-UuH$~-Ys!|a~D7xPUpK@*}0q~sjE_BwTo%UI`!=izWE+;gVz zm1=PFkHmB1`s8C15;{(F>YyxD?~G#GNF%h5s8@IHkahRwe+tt(<;?Mga|4s(xjv~e z?2EbHV~bvhJH6q9=5wMIz}T%X7}RI3rJWd+skG%td|6MUvhuX@F=_Je`I#n7?DG+& zzIUYWsq_F9;!c7GE~dUUA;f|s-I{D2EH{LVXB%ycHuYu>`ZB7##es#=Vee%rB}Jx5 zbv}j=N*x7515ktmNKc^@v?sB|%KLxk&?W9yVeHfa0rPT19jrKdnBByy*WD1K-+DS?+_je8Aq9Tk# z2#Q;m&*-F!$}UCU*SIfeWF>cEy7S|=d1GE`FqSqBsfn*5!@5Xy)jLWfNzbn0M@i3!48WW;lT-~UV1+rVh0m;U3qXT9t`Z?vA2kLG7P@)!tKer35oIf?=ju#@@+!U~&;8+y+P*3nV9h<5IX~wdPH=fz=v^ zsENz@_%9FsSnGekV;vY@?nMtwjEMHhJS4W2D6|SfoK#%bu_%Woq+2s>F<{>c48}8~ zb)Uw(1fo+(VI#?ZIcdBZm8uIO4$RuIgQ7?FFRYYfD;}ZHA{Gj@Yh>H-8-GlEar$Ex zHWj)CT-gC2x)&@-*RR{CKf0)u(HBC>KCcshD>uTtV5P(k$@PCSz8`1li_k{m$>)FF znFXdNl`_^y0HU>*h3xvyuq{p2eH0tGqpY@vYwa49cwh=1XdNOC%k;`!vTaUMSn=Nac?!Fww?VmYN3FEjM^s3+4_=JtgywkyR_5|;0WhT8Zh|7p&f+$%+DTbC ziDSTps8{sy;uZQ6KTBTiwrCjs3b&Ohs~%g~5Ro@PK+Na?017=j%BE{yDHFQ|)q z@8uE&B)|f`rfENzew>sd_rjs;uh}qrX#aqAgEnX9w3?gJck^KobK|n zMbq6r_LSS0hLg70*lv$71=?%Fp7aBDZuQ`XnvifZ(XRG+tb4Nt0=~hw=154ps6gpm z7CN}wO48et#s%SU{il+zAteaZk1H>q#nwuY&9Eq|3M6w?$<$ng;Z zOv)LcepGfCT(FZ?7I#ggUqv;baXUrO9u8m@FR|?@DaskatgJ+1*{rI?)h_TLWLh#2 zMn-T^lJHXUxFGe=_Pfb5#bnpsig}}ssP#9l{~Wol=VC#W`@1CtWa_Q0=cSuYFNu#* zzZpsqfOG(2ra7)fbs-Nbemt=L%n6-z^T2_E=9DgClEyp&Tl3cK|k?NZyt=b_nV4WWC%QcLnkg(zMa%5)YB*Bc zdi2ANimtuqBa;dS%0zj`Jr5$MJac!1u5_vPkXu*P`$TE}k;J$yH-r~slK0-&ZfCd8+xhWFg^%jb|3x6rHlr@6EIRc0J{*$4ne^ zzjZvkf>9?O)^MWS-p%16R$hQa3%cxzP7IFn422vuLx>BnpD3ibo!L;_1I#f8F9Q;? zfy+gHBD9JWMw)v?Q~Y%3pImc;g2OwD1yj#noXK}y4=Y0shX`lGQa}^&<%pa`nPHKe zlw%tjdk0pAPwFDsS8zsJD$Y+$Tx4!{+3KhPZREH#@R6`6lP9C?MFqi>l$Xf}-6I!A zmnEmt#|GR5q8UP4?w}k_p%$Gkq|(f~A$H{elzHVGn%sK1Gy&6~w6eQ-H;xnlQeN@5tF^0^=0vaXEt8qLM*!fRq$O*@|?gAunWM&lxqvv9yiGY$xzWbM9H(Ra|( zvp$_+l20&Dh?tMjg*^mKN5v6pM859_+G9xjF?|E`s&?ZrYh*SMjuVGP6bhIKpwn4r zxmy46jgb6X$IlQc5_!~5dS>pecEKide$2YBYO|=@O`r3%yd|dItgkolN?$^f~1 z&a(Oq7!TBj6y6t(6TH&Wl~I>e_M8tVGj1nfAlEp!>QfaXO0P-Yy}Mp4ryR>WecT@$ z=V#n%_iU4yK#{U<>YnR7vrZm*o6x<&AtQe1kO2*z^BLM^RIm4*zGEy%_q&)>Rz zX(Feq;H-@%w7XstvBChxh(YdF$;_BX7E}(b;O$a^uzkx8Zon0?`vm^PsTvQ7e-1{9 z-o8UaXNeD%nUGzWE;7n_HP4+d1zKun$e$*%t3=SUKYJPy-K*OFre=l?JYZnBp6)B* z8;$7n(K5j{0^+SP6)PbA_$5~U2W4bGIEq+%Bo22cGkJH(kphkLmsgWfVyu0b@FW*W31EriH5diSJScc zBRKO3Wt9kJG<*)^oxyCvAi&y6`$4rtZF;Ady<`O7Y3WX72LUbd0DH;}Bf=^T+F@B*-DheGyNTQ3?cP>9}z zWV8@h)*}_1s_RhQMiF|pI!#h!0_PvRcYU>+$Q|4a?ojDs#o(!~SAJ#h7{MW*&y z@tT|sS%;oBo?f3{^}R7=>A4Afr#t_6^oF`A-Js`; zO^O=goCwD&vQomL%BQ^jstK12gnlJ5;Ebil^YI6o!;w7Ntgg?8g%rBJ&Hns*V?Nup znOPeq7cDlNQLwMGGRA3C`ECcI5V!NyiWNM!r%xfLYD&t#kG$$S{-C?zW#7ki9C$xl z+uF>MmXe`$4lVPBscD_D7t&WU z*O-;xWnwm}U=GcflY03JC{7W1^${kNlyU4Bw}e8UOk{KnGtOWFVLQi+Gg=7dg8=6l zCVblk?<2(NAbN&E+5mQ{NpL8@``g6Q%h1^tt=a_89&v~O!V^WOCy7xIcH?3>#2G7# zn_~KY`?1UEu8$I3mtX6fnGP(qtcP&CKn`%}mX1+O&^8oKm$fIhKiV@s6m>rksca8T zmeFJu_AOw(3lz!8BHCkdwj8F-&L4&>84S>=ta;-&veUnpU&b35wq6McB}2UNz=ReP z%B~TNnx+-VLiV`O7S}GNiNE|;@XZP+m4dQe`yDfpjsVOmM06Op$JN(^!4<@y4&&e^ z^ad^p1)C1=XhoKNT!6&N?sb^ZuPM+`vsj!YrUGzsA)4&4+%A*eqP>3+hDfW32vlHOu0IN?2H~}0vNpP-U{PbgjLKo*Y zCSsGP*Na4+z$pNy;{v;rTFFjizXZ0dXS3;Z@0)9WT2WA8 zrmw!J{*3(RjmW|Kzre#knYkvcJ{<@JL!SjEjwOQOM9=A2!SgAg9~ZZQ@>4GE6&G&+ zCyFj8e2G<5MN(aS)EQaS@pgs!XX%-;Cnz@Y$+e?IM@k47{w652DxYFrA$|MV{N_8i zZuCS?0wrbA~fnHi-@lRkyBYJ7-09eHbYQ1>&;EXJu!_tG9OIvX>u; z254~2GZU&E60LE;w+Y6T06#VmH$Ehln|SxV>-9%J4So$Q8qdwP5 zLm?&YeYFH5d?#8i6E4R>LWSxAoTH<9ms9ANo>5M4zcw-Q@DIA_$*K#5+PO+XxL_z` zGlr&O@URPbKnPr>(7QmSb0}0b6dL|KUI0K<7LZFIZ9_4CdSkDF;kvS@4vk2Cao=m5 zAr~Bwg8I4W*86Hg+N+nvd>)F^qS)E+urH9LLVy26#^}!5$m`87^*dq4g&5 zC^g{##c!ebYp4l(0&puZt@l?(Ie_fE$li=k`;3TK-N6dj%AieuyLtay$`H9eLXQ)u z>#8w419ASI+k?4il;n0me1qUXGj|Hhr2E=wt&7#oKnXrf9EZ#c=JvWo1saaih3xG5 z&tT0@^*h*)=vytpSwyh|6u%Tl-#pLuV`O%af-%0bwWnnf$4M@Pfm(nY>=cSW6Agk<$IDiJHWa=`Chg3q zI!%^+R`w4~&AU~W^ch&5H8DyodjkT1(RshwwE*Q3T%)*&Ufwnzu>Rvovja48AY^BN zRTyAxlLNHqcYY1r9}o-%;3*X7Jn!D9W#2ckPalsOv8KC*q^6(Qf3N1TH9R)StN%*>UcmT>XG4lw{11H>J;uKKAyF-!FFAHy%#du`VIUSaVnotRa_E_7~O{J?b z&J%!tn&2w{{7HcSz#E=J#f@^Hj+;O{l@nIIaPu8&VS6oKZlW{U^YEQcR@5k*hiqoghi<|8^p!04=~h}c*a9`4Img1@Jn1AwUSYfGif&C z>6J9^J#sTEayhD4_EUSn_>y(LN=|>TlVjy2vX4lDh{a9!$DFQ&@_TM*Qq$pG-qGkHF_S-phIEPqqVGiOU zI(GIX{N!I*NbpWm*C?!vF7QcSqs%=58C~D(j1MUw=Qa>5t0^RKD`X5e0wH%RTqw znn>VkL|rtAap+*D^ZsU&R4WT94gA07%VFXNOd;tTLbd~ggehc&iPdU4eAdONAl~1E zvnoumALG`ZQ4v=_w5Vv~Mr&;dh=`*RQ9)@N6ctBp zRJ3R{-{1Wn5B}i)T<$LK&-?X!d69V%c?4lB2HgOHTUM97Ke*zJnv9mm$&J_#3DAhe^(v2vym5JF z1P29-hd}H@0SS?T5B0K(?8otzwU^jX2bg+6MomSho-(Wtn{zLO9D4B8PgF2|6(Q~py> zeFTr*O?kP6I-9Le63A%&g2jpgYK0xcU|hD6&oEg`R-p7B%h5&8X;<0!YP2NFPg>mY z@5B39PCw!wt6WYAG>*rV@A7LLM#J9uC@+0su!W7D!bYEe3E0uf`xEvRY{FYMY(dig z6~Gg0`dt7M^54!?Z193E{zxYbaLW>8``2pllSyeHm)L-qYAl4t$PA0I`~G=Mhk)2% z!M{UVFB(q|0w|qKXA)`i^;nq^eLO>(*c8hkDIdWf%6vF|q}y*V!)i%a%g9>U;8`QI zmpG3bMSH={Ewbrh_ECI?z-zMQ7Ft%}V+qTsb|w}wzRuEILdpgHfalzdlBxcuS8tB( zMI$3#x?#xz6DJy(2@h`P#0>E5JUce zH3GP&7;hBV6v$uy9zmEak_Thb^)j0bZoN9%BoK!=L(m&hr`>^%X$AhrWCE{O;BB!;^D=aEflZ zin=bIWmD`yJFeiu3mxX6g%+fsu)rU2e$bvgyJUoMrNzEeX!j8CxB@zF@*eYH+uK@a zdmlsatwvHbLK0~3)0f)?$^1_N_E{DxYt#vjbNlK|8WPyEv%u39C;Ki}mUrnT8I^~4 z*2wT0J=|+tMF$xn$CCdOW*>x}{HO4m0MG2pb7^|Y*l;MY$=jaZLZ>Sz9TrG^v^%B( zy{+1*3kdFd466lB-49*_@D17%eEcxfhWo51Hd|nh#r|O_&_GV?d5XTM{-kcZ)yBT| z{9?j3(w%qCP(0wb(E_g@1Cer^JKOEgQS@fy?9-B%FuL}OzUt%=Kb;O^VY_SK_30G>kERxoV+HJ6K*ZP)g7CEn7YH-Kl5K zrTwROh;MOlZ(&YYvq%8}|d zz3gqRDe^>pz?@b|eN-^1v3>Rtf$hv6ZN-fv>1I(ygH!6@S*h7PE5#_QIT*EKVC<$? zW7ib1Y%MnZ2p}B8LUgGUPk`=FlG@sI5q=4kf%67(^JL3NH(^W&L0~olN=je{Key*rq+9ICMz>(w<_CjHI-(x4q;*O9f zpZl6CXT0uMc;6}T(W!MS+n6Np>m>AZ8w48oV&6v-BD~MatowHzAGV#Na)t{qb=W-l zE=qho2bvSxpzp#nbP|(@5q?{jz^Wx0wt~31S z!arAaq-(va<4xqq3O%~%I$sZlx@iRa6EA4yONC@74n(jE8WdnD3Cwz|$A*z~O~3}@ znj2Tj$J3MC*QtdP_8KObM2#|=6YSq!ktcZua_S2iIG0>8Ic~i`%$mVIkU$BtnJth< ziLMLpmV?xsE&@$T*%+r#_QHB-+p z5i-oW8$NkVZ>&3Wh^|bP#0Zt^3v}tBZp9_ukn2d%N4!gY?zsel&(!h?N@1i9nfYbm zDAs>pXXreCZZe$-$GQ78Vnr^G+i!AK$`#m|E<6S3d*JC!?jgBfCR-5_rez<(u6@8& z)S4=vuY2Bt_bhYYQ<*u~rogA($yX$zw{Tvn{n}*_q=2<`%bd8Rt4_#kQu`snI~Yci z0-DQw`m9Ha8WLTUwRtyOnkVj6vfBhdk^=Xah*E=w4>zO+*?u}OH+53aneOCpB!rP1 z(jj~HXXv|y&ooeGviCvP`AuEDjy{cmy&Jl0Rwr#+>h_4+6Uj-}CI}$czokjhd)~84 z+kF?_LL|PdGwA727is%=ZxossTfGHCzhv!=@yP>AVp^U5NXJI9TkvbHK6TmwsHWJg z)$ZF`sTTe6h!Uy(R)>r^Q9Jq@C@p-lw(KgM&njF403D z&#vntHiYGuffr+Cp!fJF?gwNpZm|GkLC*z(7XrAaljwNCNGdldyE^i4>q;Qv96>Kf zjxIb4i~voCU{ew9ub)ydTR#|aU)L)VY%U4hJ&i_e;)|O^auD;G!*z7A1CHlECCIfH z_r>Le4QJlitpUXDuZ^d9TSiHuP7(P4^Qh}>W5&mG$`r&GXD?wx9`vpqqxhLE!ty7^NO8wO zO$vfQq=3t8qeM5IX`bI%95q~IuajXN+pRk{pYLw+fQ8^`0f6t2obuqM5VrZM0t6sV z+D>#BdkkEO3QBh_Lt5DSDJ`{cG3=~<0vi{IZm@te4Pwd_ppr9~cXrWx_bmH>>KoH$ zuW4l^c)o!4T&{hJhmwsDtW)w*Kmg#bf&g75r{^WKkiPSMZV0v)pKXf-f$=Mdg$pq z3a<}@owR@(vJ4LC5ZeERR8FUMj`=*V2tB4(bX&rQxrUVB~_%vxz<+}8F) zceYTIzrl%i3T&;t0xp7IkyaW^kjl0*Qra@yteRF(FZ!=kvNdaZtyI zTD|3&gxlmYbk)bHO*ey|8?L%<@;^`R3Kn;OkozQn`xAyD2lNm>SqvNL73g9o-WSb4 z&mRDD^ANCn+@!6x4o0){Hhu5BKIMEmloTWYDN-ZA7J$?&dko8kgkzS1FeiwMs2Z`h zEueJ#snaF4A2(5D!=)- z!>piOjbi*@68W$d`KxW?^h!fvRQJ%4IE#>J3vUO6*+wN3eU)x<61k8wD$@r?F$vaJ z0<~W4ZIt15IYOqs2g+3m(ZnAQEMVly49>S@Sa%q=>H9>Qr&&*R7h`DWb@*t3Y@cY{ zG8{U&8+XD_42TcLBX+DT~{dgJ%J>6g%c&4oEyxErs?lVS6zfSo~~l?0{~A! zPu%HrIfKTN$D-nK!R*Adn|(hr0-kwKTuOwfFRg8o$Etc_dfNxV}Nto zkZ{vM?dsfF!@1Tx+Zws^s!ZmTnk?B7N3B%s-BwT8FfR5=IPhz1`+F@z-Rkuz>qa!2 zS^c@|ncEN=tn7I2oBJExzQJ0*)yF$ef@%UQSxRS(UK3?aerX;OIXk?sOsqY~a2^b2 z!9Hp}y4$PEBFvabr*L|`s-jnGBFys2l8GUDtqkjSqH|@X6-@@=uIHx>K#310mbqIY zHXEfAL6H;`S)q*+pa%eYSpcdLa0Ck7=+MOhuw)`D9zcINV@mY2x%4~~d;H5vu{8p4 zG7rF59XJB0>;*VYv0QPHyJZ=-s(~B7hJX^7RdRSEkAPBzgFk`306_#`9elA%c!VVu zT&V%O#8SK~YQeNPu6#+o03$ zcnxpW8BmXvm`-P;8-VlnxL?lTGR(L|I@|+${8D-5J^&XfbeIDKsvbDF>d`+`l`f_4&f}&Kt{R8-HHbaH?Wsr(@ZfsEy@x_}cxr zP0L-3IM$NUsnIDYb6(#2+;LB(adAcMyD@fRflo{mK^JHrL%}`#!e7e_LoF1Jr-e>Mk?wyN(vzK*Kar-2bK7%zHL` zp{^j}|B>z8FudpLztnNF-C-LKZqO|lIFC!$;e5<+hk;};lTKMkm(C|;SmApBalk

    je1}K5jo{MhOXkTd!+OZZ52AYG(g*YF-8Uo@2`V?r(CaP0Nd`Kgc?%vU z-hewg4mw(wxmm#{g7;dtLeC%GOuhpG509tsRdm(N_LQead16H zo`NnGD=;y10T(z=zVX2~rwwLZB)@ylerVZ&`fvmW{uDE4gASW1=TnZX_!o^h7h<^e znxe^!L4LW0O<9-4F-OQMj4Pd)#1<=g*f?D$Cj{vt`0Pl&5z3>(D3OUOuH~<&Xt9^= zjlA`Zo|<}=JYd}RZW0%+hlcVo5v7C|d|s*<^pb*n$GYr7uYm&a2OYkEPMB%@|9E}2 z*)7+C>YlI)fL8&~FP%(#1oT0HKY-$Nj69&DhMAUXtnlkec+?r*Zo#btD39pGY8@fA z3g=v&aK}JKW64quye%A-(D7h$>{Y?fg`Jto`?!BWT^NjjjM zx+8K^Uf!~OpDhkVP7&q1_C%33w(pWJC&<>Y zu9xoGU~Kq})G+1v)K)*z-14bMtn5K_fWX^)_7d&vUYbp|_luSMUH|j8;XVJ;(Y_d| zR-WOVnfk;)(Mk7wH@2SLyX1|I`dTWn@1ZWb+4|)=wXc)*Oht9Xs)8&-|AL!Kx zo+|MFv9XO-Hs8qzx})g|D>z!NU|N7BPX50!Aaqvw;WOw*C*5O91g1zo1Hh5OHVpiq z1E_6CQCv*f&yScpg-Ms;9?TEUw9#=#9nuz@-)7ynFAPiEkV-gx{_wxh!b}~Zt|%kuc0P_W zoJ{%WU-)(dK8t>Mco5!b`2Q21Xmo08&c}IUr^*j3V}@nWNj@^%u+=xthzUg*QbW~} z^SA;jgyIS4>3z=sv)(YzaMbr1IFo_9t*6B#Qtlh617^)5UH*M`d9^r%J31msaw+zR6f=dFedp`KDTa-eNvBSH%sE|wR~b(I@%R^w zF6Xe{M&--iaZZ+-R|D=-ruXZol7jj!q{sEhqt4^zwhwMb^ZtJSw>S*wO>Igtwsk3uhPdGxK zDV|ueLQnhJPpitM+_CQ4ewB953|3fww&|(PRk2w}^9+rzJ{-5;6Qjg}-@zuDtu}Io zk(?s{Sypg31QTHd6_F4<0n9LfPP*kMlkd_EAQNzpHUc~&X1xIWNM{$DnCC|)4#_De zdGH@H$|b~ZRA;xt>cEthZg>N3I(K6JMsOi-2JkqEkh*_qcF2ZmT-E;Ha;TOvBn+$}ta{O`&7-y&vj)Feu&AV*?7ecAC=-=7^B9(Ndnog;udzf2g8f58! zp8Ahj!rw#*KTCduoFiOKVC{j;&Lp}$C6V*n1FK=g%3v#oLSiWzG%GKKowJ2?X6ziE z!{JStd=RH~e6(;m^W`h%Oh%xW0aIv(cd$u2WH^N!O-#cZ0sICAX)c>|ij7LqP6zE_ zm37g+2k^oB4R=X|@}FWaKg-$pQ|xN@8&QId&jhgoU@CqJ<8<^rwj-Fbu@Kt-i?eej;%7vqT|CKg8M!RZuT|NKK=xIgzWjy|C7(L zkV-R+`3vnIo-l_$W3C?Hbmh3waSR=7yQBi}bQjP7{N@Oz3lcQ!Xj~t8z`A*d9y(wq zN&t+949;Flcw?b`F;k=PG+saXzF_mQ+q5T^C9+jJ zc%za0xm;NPIXp*)t2NvGVz#@2o}n0Ue!6QKKg`)%)c>rT5B~DVEDOQbQg&MP4Hh49 zsP8>v>l4AVN0wLNXQ|F7(dX?FAI;FP{&>xb{^BoG?Ed%nti$PgpMT8cd2-w`Z_4|l zyYh_K)nbQep+ggZi;+RslI3aMOe8^G->OTjUr==!_xEaCIveLsN57}R?o$2r26%F?Jufh%|Y6PkVaQ+9v$siPGeGxgWcEIVD!{xd<&f_SG+@7_3m%B9m6$T9y}4g7W6 zHZb+@I^1O1#>4aUPlKFaKMvgYoY*Lvy82O!_J8j#R7&1TX`ke@19zwg4$|^nMvnyk z9CV6WQnlCh5aqFz`nTn*MgSdP0JZZk_*!2sL;-4Gt}hSrreBO;|5qf{*GM%Pmamsv z;YW08i=2S!>ko~T5BfdWHv-fn{UY7_-Lm_>!0^Nf+8N;-9xh!9<$Qu^t(4#XeLk!U z>9=0g2%2w_T|*Uk)XY(8ROwb|tseOA#d*gr(9r^}KE3~6ua{m%z}E;k0rZTurLN4N zwAQ#B*WJJa=^%T~p~A0-)gpu{>|Dcgn|KWzin!d6YrE-tg3*?AtI4|| zV_v=Y{%U4b!=64mzo<@g+faXWCOJ5R9N23AoN+FxXdw4;!!ccJQh8JOJhcYX%yru` zQeMEVYdKnB@0fjlPntJhmQ*riEhKOXA55h0Ppr$1{B5u?@Z+s8cg}%(je*LZtbONK z-RsX^w7yd8Lvvh=5^qQFc-Po0eO_d~#Px>**i{Ws%3$k^l_bT!v}JdKq(5ivCOz*; zBGO&LkV&hketPLW{hHky0@8mezQq!r^{?CaY1#dzyvxVm-jgNN5a)HTfbE(C=P^T< z9NUrJ9O~P6;?0enho^72-%MG2RhZhyo!Bp`mgv*5cEN_e?hso9IuYAN@VbqzMC!Rh zB{{+@=+q@P3zYSIqd-aI^X^rWC;Cj}+23HvPTGbcBgS)+1x?45A$qk(21~+Zdp(sj z3S5h^wCPr<{mzIFtmike_H=eX_oo95{lSgv*ER+xb}ijCJ-P}20)OVr1A^~i*+4|s zBz9-iL*7b~*b{o*moDjA+L(x4IQOO?v89J0@?A6L7|$&~TEAd*OIy9*Aw0IPH~dpAKz)eKYC)oWIKPyzgH>1UDBC7XRe6{p*^S9;JJ35)N%2%+s6)V#g7TTVf}= zB{n$Z(+GFyYfjUu&*iDO1)Dkb!JQ^?U7IZ=H$-9+di+kF@5{D;g=pC+xkR| zLUNU;n%`7~)RneYX-@9H8zdoZNzp1vVruydm$HuX-iFje`<5Tt9(-o(F9i`lRu|OH z>%tP*Qq&_sJE|}+*P?H$o6cTJp5}TcvM%auQ4fQTALTS=#@{@(YqqHPs8f=wMW@^u z&ns>tB&*qK-mz4SYudi_V+YqCy7EfECv~-=l~9T?-O;9-g)TgHQy!jmMeAB0rC;{2 znm@=5UctAoy;V*3vE(tXx^`Euh~9PqGBT1`-Zk8g&V0LQb5_4D(2$i_tl41S9*7orvu*A5~r7|T@%}^w7R6#5x-mS(t1oT%Wg-`ae249xUjkW z=wwP<`^Ef#k!?daZ+LAx;ZnQLN|-6BrYCAeYodM9Uu1@~2@xoy+ksb4}oT*vZk?)i-*&QYTIZ4=?4#KBB zhZ>ohX>~rdS$&w@$0WzNI5r42&1HQsS4^9FGbIxcyVCXkfuN-_<3k?QHq;}nHOWKX z4;;QSnGnxyt(j`$JlkcpR%MpGqN+L*qfeNqrF^l+TXSXHh{9{b{AV8b`S0Lbt~WUp z=XN*F3s=WZbnV>IGuP{DQ)1M_eR6KE+RbWC{1;uPInVSVaXL*%EKeWEd**VhtaplK zbjLPB^;Uv-D3Wy1P`FLijE@|O8+(68we^`BX|swLX23+|0`%n;rB{!+CKq|;^s!A6 z_RgYRoa{roRCYan%{{wC>yOy~Y{W-z%Cjd~$sCT)O`OvvMO1rt>>DF3pLC}77Kz=4`!sy^s7Ikh#o)fL@@_*aN`GCoC= zfm8%^JE9AOC-vyil+oq%*-1&!vC&Dkdxj6ATH48B#8ydsqmA0V^5;RZ3&{`6N~aW_ z61r(Za3N+z%>O6Vlhuw+bjbd#$nhMk2{STxH! z&w8PooRO2so-43t{rkmj7`yB8%)QES zItk8}lO*gBfhNl0?KV1oW5x&Y&YI%XzPDvb(`vt+*kWn;A@#FeWcz5rn|86st1LX{ zY**dJs~CIOTZoAVKtfT+F4Qrc&wdXc8A!;KzY)M}Ic~Vxhp}*?hP!^Qw#8GqbK9)H zSA2!b&ZbjR$%93V~UBb3}6W_O*{kvICRY}SD z!E^FnM4s=u_2YW6{a?TZU!Qd1$`Q!cLEc4Kp49^f(IK+>I>vWWfc{ueaJLvS-Y_AiX{za{Mu&|zJ>R)KUgc6gR1J5ACJ(e*Ms2SK5l=^uk<0Wa*%0Lz+rp#R%E zv~823(UUMGauxGDK!{~;x7%h_u}09PXn7YZXXAP8d6WWqx2w*eSt(iP^lre4h~Jtnu2=x%2&dJO6pE#b`4Vc27zI)_!+5hDm_7`k!Rso(f!3)_+nk$n{?;iF6US8+d4a|g*khO3r8acKkOSRgFOW}g+knXqxQJL1LCC~5?Wr$|QMOt_j(YlXlfRYj7!S;wE zWM>h)^R#99hN>D#uv!#@#VAP0MTAv`<8%Ww=doMhkXxrscV)eudoa&kJOuIJm}sps zRI6kd5C19*BTYsGYJu7{LE<9(avt`*OYe6<%^RI2VUxSqaB3*^Qa{l>r%f4RJd}n; zl-UZjmV|6XP~8(o(1bM@{S2WO{V<~sBbTc~Ix&%2hk^ExaI}PIobaKzGL#3na}_ie zINApU$trJX)i}6@Vz?Nghwvz5hZ2_36ELfjD@6)j7S3)><5?iQu!0`^0hGS<= zR%WC`B=8*Pm7gl+ogI@Al7|p_%=M%u=cabovnPnfMOaLJ{Glz7zftYaR@j+A=1st! z^}`rQq}@Y^eF#wyd0wiv6~akVgFqpoth z@eazxdf6GhFvjg@Wt4s2aV+`h`IJleco+!g#;HgCjextGCSBAToXFKqbtL9Jc z{GS$T`nW%5J|fl;R$4sH#bPJn&LvZ?(^%bMgC?!_josQ);$&4qMg|HW*Tm zwkYR)qrGGCw~WRl98nyOXTLcSHfvmk3UVGs#XnzfkgcbQ5%n9l0*ub;3)bVZ$JJtk zdgf5(wg;KpFQBdJSj9p0f^n=Y5|>qkn?uLW1ei0M)F~|WbQU(&2yz5UcGJ*wL@5}@ z_`?vFhYmVHZj6FyK!_~h!61Uk0`{&#`r;9=Y}Ty<(;+UxWFyCokQ*1U2H$a$gRV_8 z5C5vbj?cOxNwX8&5y=kQwhYaf`G!zF^xNi60{e~A$y1eo+^&qN_C3z2{XS`Dv+T42 z2Tz^xZ7)*u^>>$^0j%@Sh3famhGQKHy6cDs!eFpfq3NaLR!q7Vvv9v@yW*R`pNxBo zvMP&Xu5cPX>+Ut-(TzYYK;_k;e~fVeSEkwC**GhmP@P{{np|00Mp!M!CKkEbvq1-= z36*=N$T0hFWBlsWr-~#g*56;cvL8Qi4nypsb!w#vr?^rPg@~v3RLtn9*tQmzWYCl# z8ni)$Xx09+2_Y0+tWUm3&_b&PYFWfUk}k+atL7F3J#jse5pv*SaW$8xZi>OC2vq6r zikep26oXY-XeE5hs$5-!AsCd4Ymt<*2~qbCBhRiRZgKiq+O>2PDFsdK%BhAjY$=QU zsve3)6a=8VJy2~*QDxNDd_;R1*s^E@n`jl(?L&ykTo*auYPf!RNQr)i!lH_mMHK=- z9cPvVaXq6aRUzoAmoBtOrec9R$j8|XD>82c{C@8WSKHeL(7}t!1UTglf?uAzGe2!7 zSY`Jy>Oi}qQY_fJaDXuHeq|}JD85;EjrZLV3L# zOB%-TQ6Jf?uoUIDO}Olm^%Km1i@rHn|NA+g(QU(-ic^r{SL__y!$&hG$I`cDZT}UU zNME+?`J*Z9=P4|dDO|?QdcN?Q+QF#UI!)2=TolGZV8ipi357$mH7EN4b_QN>hb2?+#?`)Asv@7*3+B>1YjONSRnG6id%7Ek0)?jwg-;B^Xr_0hVKIyEe zXqsY#(xs=rlPgLP^(q6t2z_!1um_J+#|vP_vFo-kPZS9{wG_c8wPSgmcvg`=dvXRGwopVUTve2(%44f)wP#n`?g^_1(=k%z}|`zIR|R`cSA0+k}dmO z|9b8Dc+m-da=SJ+XzQB$E53}wolxO%SC#+@k5Nr4f<$aAsz;0Zp!RAcbR4&^uOfrX zT_+$!4k@GB6VkU5)=3xPv`|DR;1Y9atT^-et4qaO)L9eeA)h)vmm=1xvwI+SGLB&t zUfS@+nz;X*OywZM?>UYC)o5#sdT5sxINoM$W55b*`mU#is9gt;VW0{LD zK6iX+3Z0Lz9xY>e0jx{vg{#-(crFVYeB==Goc-6MAldt^t5n(76vy=>tN0E&-eNYr zXkC!zroR<5N4>3x;A3apV0DV?uMD|WX?W(FzXik95oXrxUJm#6-8{f$hn=}k%s(Gg z^tvi>1j5Rn;Mg*F=p%(Weq6r@vLY2D0Q_)wbH3*z<4wH=e`tuP+*TdM(`{P{ls$Q~ z)LZs_6xd-??KOB*8VW|Pd|X+w(gY1ytr z8)r^1=wOF2_B{JMnCHI7PRt*AN=dFR!s?1flDTA`8FVCA*S!9}uC0&WD5t1P1RXu^ zg>*y?XFsA68O~uv+mah2lB{ z==X20PDZCKyRgN!_}beWvps&>`gY_(&ga{6cTZVxB4$@_iA$SEBbt`FE15knq39wv zyV*bO5I-kTD__@Fxzl;Etl$fWD?K78#Egaqlk6i4|NcwTIPMJC^vEuUk~c57~U zB-Yl(mx&4XIV7R-P8aE1;UKAey*w-oOSkB@eRt8JyAIq4UBsD$%QCGLcak6yU<}14!g-* zPjHi@6MZ@3j_4RPxl`|bOWuQb&Xx;%J#yAFalZBKXF%7Tyn#ER4P*CPr~fy0WKQju zt4HRYUr35qRtqzVpuWl3M{jB;@RwI(wSYfxWCUoyBqi9 zN@QKctCC%J%l>=&@%QPUzJ1(Tb}#YltQp}gjGJ{IzhN9^z2h~f@$boZJ&S##cDNh; zMBbeiZC2Z-aJ0typJ2}{ivYvmOV>Enaq@8}i+@JLA7wYv;|4hS@P-;bMD)!lp(~P7 zR*c;%C}Ms{=DeHyI4kl&`;ncs->$}amdBo$$B%nbgrUuu-lD*9zIL3u?W5rxsZF05 z_!0K3Zp-U-EN3<(dFt4;ed}y1zk$898Mh4;9vfo9s{_{V^y%l;S-f%0w|Rrc-1RBQ z-X-+KJjIgxr83+G^Mb~q$r}J>cisW;*B_8E%NYW1h!_5aX?>z(w zKds_D8Rz}(B=W9VKNjJXrGBq~>{kn1^E1?H6sRotw{ujO|h&o!8+nVaN$2`(X>7JL?`LTjp5&#j88 zXC?+Z3A*?De9$Jki?ACaQsLY^)A*rcuW!kO$PXaR?F-rSZ=^YP;=<2dwPh+>JsV2B z(M%B_Ak*~^d8e(MpC&i3 ztq}x7kM#t)ww@JKgdz$u+C>3-CPMbfmKu*Q`B-~H95GM;dZhSb=sE;vT(fbWPW1Vz z1|Os!B_YCYrn${u>@KpeW>w7QzOP*-!7#_zNi`k!D$=Y%hvp%*Kh`^Co(`g&9G3*f zbXCtDQo7m7Dz6>Nt1F%*c6%+*iVi4n<+VR~e(FrQkj*@_#aQI^-I7>4Sx@6%y|wuB z>&gp5^*cN~&K32~BE;E_(yiy=y}ZW{<7^8IieUQn#_~i2Y7kx6|t$eJJWwJe(x7HQ)0Ll6BGVNRy`bouh5!wt1W!N#t1G z*E(c9rIhGQe>aYLP78Xr0i+Z?91-7;a3BUdVV93pS3%EOMpc?%h}f+Aa4JVpu&I^c!oJZ5BB_G{TX(Jj%^qyFFO%35f+N zG+cLM#>ASgq8Gj%R(ZY1XNvt(9t@Dj|G=w1UfD}I_Ve(hIFBuTNI?KXa25!6p2e`p z;~181CY~?YOJD3w9u9t)6m8ImC-`=k(p9eLMk%i`j}Y75o%L8+8Byz9BL;-u3QjV@ zQ%7q~wTQB&APE048~R0zsOiRh=U;|u!!*7%TT%3(Nnbl<_7OJ+Is*5^0CuR9(1)6$ zu!z2YNZqh+R!ZDN3r#iI!^YQB@@>~J@&A)n#+#n)*wney^Y`%^(@x|SM!s3<@&-xd zpXj02UN9s&umszu)D;&s!_@xDIIX`~R&-EiDo6 zdo4RAdMFGO5sU%Uh#@g5nk7@t~0*f+qp4IMd6@tJM2&7@ZT|>kIlVNoGz_+ z^4N-dS?3QlovGV4|7k#FnaX>Jt{J&|kCfFUD&qE+EN{!>T+Jg)xd>9y`4UDZB4wUs zS7m+^I{&O!qXXF0DMQ)zLrZ%+#xK`}NW!;$@Kr3~nw>Y#5parSnD?yrok_`#Ywt2{ zE`_$aiR1I_++ZObZ8;vyvMKTIS~QFsjHAjC?CH(3&K!}0Ls_F8kHY)+on9&lk^_`n zaW^SFM){v~6ywupvg5NsQc62!dB5qEK^vVT7TL>qR9E&H-2QMVp(6RKkB`sAh$?iO zXIy~m&C#YxOs=)i@#^7&TGtE#x0xI{aaGMsqZn zocEY~J&Jmb-i!u zt1q1{4Pbtc9+9%@HWv^kD!K@F->Jw~oiJWa#yJ6W32|1@=$_H9<5}dUO`oesz5u#+ z%cyxqSZZT+w!wPWHqt~fGT6)~&(vQb;Z@8**$e&nnCa>e(`sLy}xzHnYF^0y@7r6gp^r~rPiYCq z9O3j3rYrrjOegQCnuOJU=9$ygEhk&N+qA=TzDXqIEpDDlzxk%fWRW?r*2W2eB9l$N z-yZ10Q!k`ofwBrD8Q5_s+uL`~2N;7-SoO$h2jpdn^LZUtNmU|9%A25{ft^&QmS>ogq? zaw$+aa#H=pY0)>O0mSDc`v>s$3zA>7vWi8S5biUIsM*sr z$_uB&9!A8;VxIGdnLuA`1aAKZ%rY8mVq?2FrcJy6)j4%NQJo67v(_O&uw_(0`2xT) zgm(2|ZjtfZWaemRty*H7ryoh*q+Z_HrjP@ofF@B$$=m?>+`vYta03*u-qf(w zf7htcmB&(flj!YB=oaIO(_U0Wh;au+7daTC&w5 z+Ue`rR=tU>j6GV*?O87A3#dLYhr=NmIW6e%R_ntnK$M z>W<{xi5r3An=n%b$G2v6?Uz6<7Adh-AYCI^oCaVgaqzdh3&Oj7`i$L(f+GM3QX_uy z`XX}$K~}Ngiwnn!gq_|IWWqaUr8}3VGyfJiD|%p=FSZZ4y4y%C5YpZnoyz)b9^blr zVoW~qM}vG!uE{a+Z*irSMz`a`ba!#$L{!g=2mDL%&cQb9KxEowJTSqTsZ&x6e+0_c)S^ z>A3jXIR4sLThR7*pV57nkxY&b-Mw}eCT_gpcs*fun>aez}PGW z^RUleql7JAT6Nv|?XM3NmP~w;559fTUipx$HWMqzWy*5%*p+S_5Vu1VyxtUmhyoCZ zdmCC5TNq+6E!_l7gqC!RmITc53Na|T22((lD4@&~a69b}v+dX)W{Rh|u&+8Kpu@Cm zt$#qNe*i8dKyC^sc@ZRs0u67Mr`mrCb@AILDpD5CJYml#2Sf=lZtdOIa#0dQ_9Ism z$y51_bzf`mo-P5$gjGwH$@rX}ST;6Di1C%0h=MLZ?N3DhLE+?2Lak^xQN8`*(zuH^ zvnSZJ&lVqvXka+ocQ9Wvw2Jxk?zN%D>sd@-u1AyjZW7NJ7;ut zm z?5#SzM?o$E4?I5Yw^)33HfF}V6~>as$0}~MRTRW9SGG~eXBT6RbP?af*fRWbd7*dP zg|S&31AK`}_Uh>iqM%eZa;v@f6IH`3=yuf_c_$WZU2v0{m$Lo(Lq3$gM7!XYM9_@P zUZw+lMwDJ#l<195L++6e`Zi@bh; zLYQma?<-N5+2nI9xK9~aYYM&!1r(IrQkuB3)LZ!<0%RZ9XiJ|QE$R~mSirz_kpDB` z?!0RMA&_%Z!IllV)Pl_03YS`iOW9qo3XSKY-3i3yP9min7y50XiZAbS)yfFzpbt;U zXU*b-ppT(0@7zl7!F>;6`y;O^9_@J9A?+F(^`%WW6^$;P`oACVSD51Vtn)RT+p^=} zsXr7hMTdP2>o|Z1r!glc()6Y6Wzm^PnyD%#0W-(iL(~ZzV?3onMy5O8v8-jr`<}@8 zU4GV{%1zu#n+cR23WF?b|sKQ6hA$GkRF) zFe~|u(74DjXZNArrMWjE)Ch>m&uVC0l<1&l(B==V^M)u=fC%*BbzO&6gO2Ob>DFt4 z^dgBGyb{gsVX{@vu!!l-``~9E(k@p3{kEcN&?R;(?b#qINjpiDJ`&Tw^OQ{`JpTAIecB$E zNyPc;E6OJUtqL1H?%u9Amj=3Qf{h0nVcn!ICrEFxUq$^SozMk1cC|agKjE{#Wa*IBFyZt*=;_X^{ zw|4@5qXvlIguyYUfGx^t)gsqQQSgW;Y(yD+UK#e*6xJ6oaT%IOTjHyQP-WC>3xvcy z@&X`VK&0}5*zu3}zbM=~6!C|<7;KUoqHrA(e8CCnI?!_raJlN?##XpV!NWa^4qLih zN5Erz5jPX^J8xo3lz}7Aj+dsWTuhi+e4Y(}Q~(MQ03%sC8WP3qf8qy~O;<*6mD8(D zvqtdfKr&V;=EF~Cg{kRrp^@t`)I$jIb}F&Sjm?Lq{qLh~#$T^yJj(Zd8#(=%_kq8T zbKFaPla+04<;`)TLnl%DL1Yex4xEehjZxYfuIugokD>DnNV4n0IC}~LZXAItx45@~ z6Sd4$X_}#xxiXwtS`RYZ1I--Sz?qsAnz>pyQ?pWAT2^Rkwy~_Nr-%3D8{awS-1mK* z>;L=xYz*1h=?u1_qWu0*Qo@CMJ@Qpr9)4T6RVl*pQkLVI_Qu$$FJZl~MQfXQgMgQ5FyI{#E{%OgEU82qC zNcd{&5TIWiM3@s zBAz=Fb1NEO1OSMpn1KHhR96mAhcy%B6E@$YV)YQ&RM?X=n6=A39Ae)mfVGYr*@O;R z9{8^Vi(JUjsqxH$4{e1Eh03{+&bg4~+&FMHVF9>7+uYIzsH8`oT2%<9y{|BF5?u)y!VJEcbW`~Yv6mUpc?7NU^zw~_1NXEv zjY|^e8@;a7Upf20DCp9G%Z?AK9uT`@k2QL}RDNJ+9XH=_1%2uQm8>M)5%Vn@a`Co_nuDA1KSKlRWqh`SV?|_j@|$g)~~F_3j8aRlFXG^;O+5 zQFSJ21crU%Bzl88G_Q%*vfns%7n~<&Ez@P=5PwPF(`KY;U3pSxG{?FpKN^ghb{{Up zcBs-S!uPg$m_&>X<*KJn#lc69y{6qe7s14jDkZ%r_r7pS`W}8yJ$Y)zOY7)g_&DjG zzlSeoe^Z-1pIs3$JASO@V2Y2<{~UsRb&9NKHIG-O-8+A5j&@I~NpA{`xv2Tv2mTs)p0SBE9d4#zW(-)|A7eI$2+w|vxF{$bUDWr*Kur^l9)`Je91$6%xK>lH2>UnS*kx6die_itqv z8of4XC9(PT2oz6|4XQu!XeC3SfGvm>xQiW@A+{o!j_HX0T2vm?yH7xay3Y-A=nEE! znu;&O4xc~f7B@ms91hdR)2~`4wqLuTZG5#WDN0(n60kc-Ev5!-T4{K;xPa2*x#veI*sXo79FCWz{Xn0a)l`RYsf~^c<6zdN z4My&#%>`!NC_6FEhRt@U8S+*?v%X#**TEB$jCmoniZyMJ9KHln?s9bn9@n4^J zTaI%fDc+hh|9f}gSUjX6+wA6urBB8=>2^zBlkb7|#-(ktLH?Qh-cH{-8*H}c0E27X z6sH5Wx;K#>gQ(OMMH~QPO#%g$dl|G2!srGt)|c3OGjoe|v2=Ihbf{N-UnjOp_({1m z*5i`>s4NQu%E>y4L@A}iW!!f+Ey_a(vOt90Fv@2EJV>fhBN09zgNlF#T>~Ux&F{p1QEg!E zgylivO3k#h{-eZxtJszU>hl5jnC)GvqupR)Vb0m;FZv}D2V2m!^@DOlOkN`Mtrag6 ztg3r3{}|-m_RNGlVz&U=H^G8i3YTrZ?4E8t-sblsAkbxwJk+{xyVqbG+{vq^^h$)g zI*&3!lw~j8R2>386QY9+L|7iA!APXRQz2W~4?Gv0&42)S%j1jqxhF0feQ{ZPf8`lS z?rqazo3_D~M2M~cU+T8vR(l(=H#k2A)2w>D`u;6F zoAmyvx*r`eeu1mjVLzA8&&Zv%aX)>u=XXu3;)Z6{#Jf@7xzgBtH^st$5v6$1L*s5M zO$jcXJf#fvTu?X=csZ2#_cw9NUkAa?@1agDZlw|2MlC%wywIQ9DGZX+3O=y6QRO7u zm=%TocS54)HwS;h`kb0OP-af zskXdiL5S1typqD8a*>^X&p&09jECKXPV-6yfHt1_IR8An54FaeaeBj zxiR}<=q$$5T3bGjaXltha-%Ta85Pa|KR}nu-w52${F4b+=_b`FNf=Q3aki2h5Ean9 zMLrt=(fUh;>m?RqbA+50XBymob1}FuHn@sz@P&n4Wnr;Q++T(~MAm?dl7GRFn`EK; z1aMyfn7q2E&B&Z3p{Ly;4J7mvPTC7_>vnI~X9v)vC=4Usr4NlBwUqOMcT!j5`%iJQp)P`!(MSkbs^`T^>{___{Z+|}u8 z%cieHrs_>OYPy5pD^_q?i#Xjy936$(Qjha(Z<3>7ZG0Oaj^XC0*epM6Zy754Bi4Jk z*K4>*CdX9G!Pv?=#)1@GD&fT7l1DvGBEVPGjfu0+ml7@Uen--(>Cr|mVY4$8)uL}M|;;jyf{`5sHS|0i4 zE5aEV*>oAIC(*WAi;U9+ZJYwPT=_qcA^*Z?%IV2PGLyZ&;Ag=~I`E8HhLyEVUk*v+|;H_mX-aFu7*BD;p`a z)>h{)q4Z?4@KJJ)r!n%h`7eH$uOIKco_2Tza6JG04jrY6&v86|8YB3QxsV!`{n6>G z1iQ|3>b;JedhGB97u+7Kw$cWCt-w(Pyy$BMDpPZ`0l;aJm_W_BDn?}|D2AiWguoqL zacafitz_Q-x(uqB0J5r&?qQ(K7)2)@x+KNB-ZB7nvGQtK^D;U@{0$YgX&#phov861 z?!bG>G=kP5R-25u!+|(oQg&e>ZpWgH0<_R^8vi}~`Wm|rCY^AhvrxX&bigAKe}nwR zpZqCluqq@HMg=yp#62w3s@Xo>L925*;BEmV7WB7STGpZ6BRSdf;|(+Xoa#q zBK`?U_5)8?#D3Vos^(PHysg>`BYKbt9^{yYxU&MXN=iXZ?AI8%^w{lPu^I2H8%*?4 z$9_(^nRMh#Z45>m6i!W=5Y4=fyaswMkmGS=`JScwO3*iZPAA zS5oBP2>$R*!FlxNDfUf9GokwiA!$%udMe5yIm$fs$5!1aE7K4!htrPQ474f@O}7Gd zeR*yPI4uW(c*sF0BwC7aq>F){V`T4cQ3L-8}oMu_Jw#1}dIqC6@I^P1_X zP~U`#il*5*b*4IAKkcYS-c31#epBYcjf5zOvIya) z?dB~)n6n%^LabqIsrST;x4Fs}e^Kf)5}}VFI}g8XCw;9+8^bM)0kr57T(TMX=nowTpQel4ihmt`Q0TxBj&mkBns4%1U z6-gIWSit;~QTw6>)o;~XjMl5n{ZL`hG;RS^PXPfvt&55PS9Jis1P8KALbxy~Mi8*aed)y0@sj^J0Qv2umuyEnQ%m0y+& z2@XatFlXw#i2h>2(Bv}F$jrTzsG^K{mQ}Awl z{q4YU3vD_HHAf=oR$|t^@Wp|mZ~hlOdm{?FAzph#iVZ-04>xy)ng;{C5a?GuNSq^W z%_g1pYB;vj4f2E~H)V|%)Z;qDI@Q`uNxS^U1ollPEK^OtxVacTdy7QTfV3HhgLr{s4=vtJVaA;rD;3_ zYKRs!WKMJEGQ3ED7EUTnaPmYVxg`c-nTp6~$lIfL(8a*R7nc z&n$MSSzTWtYkdML7*N$y)Foq_S&;77B1-_s|FnJ#Umx`o zyj>ei5sdTzh>AdrBI|8v)mt;Yx30=~=R&@D*}eP%)Yl=qD^6Wq$O_Ig)QyRo;Cq$* z@K^0P^z$9JjvB}QTbMNxej5spgS=CFvUp=TORc5;4&J*?JIz{2X@^jLP!pk_irw)1SeD7atcLHDUZ^3Jt582(fs6679^-o;v zlisuXyS8fXb7-2rt9LTX+@{6vJuzvp`DT|0|2{vnsA*m2Maj~bq`iI% z?7j!=nu*f7y7$e4#rAs5%PWhZX_1bjveU(^#0Ay?7dOJ#7|Koj&QcCi@eXVc0L8yd zcA$coFLmilpjI)8zyg0C>eM)39x!}+#i=X%r(r|XN^e2yV#D39?AU^;TmSQ2U2G`7 z>GZBYOML*oAO_wZy?N^zKWpixea184?D(Ch^fv&VxUr$TH)hT9Ul?{QLVH`?8U+y@ zX|4t&AfWBce5joBF^qEnEOJS&MF;9ldg~P-BRX5Radch9qn;#4;*<8x##GYRuP+<- zYl?dQJp=4PaI~qQ=s+o2H3yVIg_yCvcz`i~`Y%G=CT8X(lxht*QhI zIV3MnEvAJgE!>l4qGpbN*H6-LY}#pas>tqZom&a9?3Bl^2X#StyYHp{cSCPRhkp1` z!^sUg{-sAQ$zt#IL+%iSx-C`0s%A%`wCt%OGDbnA2Q*!wB{^8d>M{JI~cs_f$Xhu8)BXG79 z*KFHAR%Ngom+YUq;;U$s7&=2xB2U|~;m41N2bCF=ic#f+j`Cc&?8TKq1UXYoLqN6c z*=#Mocrhni>d9Ct*lLTTfsM4qKN}`bFskAF#-AgH=d%h(< z&|{rW{4E3EC{;KQ@6es<=;kPgRWij;t?hzAXmLt8Om#kBP+2=A47@}4m^FJ`yEyiP zp>`%U<(~FATS5Vm&p5E-M1Sak9cI1D6@JE7#`hWerY&Q7B0f$C zhcbpHNP=ZH#ilcD@v4*b z%dKHk+g^J&AGe6(-~G%ZkEE5^e;qk++8*U#<8N_O;)WgcIsg5k!;NuPu6)=hAxF+t zu71#Fy(T!wXKM*q(O&mHsFZS>55_wFTxgte(Q&Z}3bpkdJ#+6}fbR7@OU}ln4sZEZ zr4CE1$Di$oSR5*-%u^@=75!&$<4JN+hCzeH>XOE<}3>SOsX|cr*kd zR0XN$HwC&8i}L|Ms5VvmxMFr&t1Mlt**`=~3LVDk3Rp1u6*`Jk0opET!Z+n$5{+|Y zgDE_l&Vn5~jDW}y-AgETL1r$IhP0ms;wjN;jPY6J>|soj2bjPYvgI@QP`oVEB1S)f z(uGd88qxT2gAo~XF`w?SM68$5QI+v159_6Zod>y!YB~8%_6@=33GU{$GAVYpcG#6r z=vHcH;cn)V!WDqB+BSTL)hUo7v+1eY{kg)Jm4>jc*iotJt#NPfjwOX_BA#(lGCDC0 zEQ7P>IVet)r1Oi^KdY+PJxRwM|2wStzqNZ#3o~UI5eN1Q8BnvXXhQW753Mzv+8{Fr zz0cZ1b;-cXB7P9fR^r62ieQeR>p)6qB3!X?5!95ujguu<(G2cnxay zDCgS6`bL#!&IQgN>iMbXz$#wB2&1g^j>vI0)hl$OLux%be`9EOFMtY|w7YYQzIZIC zb?TocnTaYO3Vk+*K}6-`LP|tgQQ=w;QA`bD`2q3dfma#2#F98?oufne60ct^GoKC~ zR>$#ndp{BP3*u1NTeazFm<~~x^YSxb8=Uw$&t9|(e5ec(u@cV!fwx?zG0=(m zAk}dR(2`3+odDubx(dhx-$ubK{DlZR60q~fQx7BUEk$NH7CxJVN>I&FA_{r;!#L^* z66*h8BAB)FQ7}SyG3msC5kHd!XW6Gvay?+N+ImE@fh00OP4U!k@;r`g0hVkm zJj3^nHt1ZaG2CX&#{VPY=zNh<@@StJN_Ao`2{u@;V;WU*8FxSpPBRV|VedIwdB!9q zrz@f_Smo1{r`C$SiGQWVU2Nni>Dx<`xnnf0C3DntuY+k~Hdt#;^|_1q(8NE&k6Qn- zp1bTm=aZKCNmKW@ckk4CQGtP{i60PVHnR92xpKH}6c<)<=shOI9t1hqS~Hp{WLMiq zM-Ab;%;+ta%=`nY!=d^UqMSp9?$-~CSuJJ#k%3Jk}sF%0@17`#}}70f}(pa(K(< znBMI>!ADORpLcA@k=h4+-_dyG_sdL1qvNx`)=cAlhJL&k4vzxiwtW^4kDBdCxafH4 z*(C6h@6#E(FUzLszL*8i1J}=PBp$#3NpI9R5qre4Yo5G#H8BqLl(2}72+|=?wmNCs~?mv{`W;mJtxoW@1p!p z2v2A1X1bJO9S!dlKsEP}jrUB&t8`6{n!P?=bnM0C?h|Pq*FW?b|7ABQc8EQ=4uOS1 ziP1{UxDkEWHDnN?;p=g`QT^=}`LtV0>*pv?la%uN7xfO3zA{v9*}jp_d-_)Mm#iVY z{S!9LS?&y02WeC{RM|w=t@=yIv-`AONa*%}^?l6&QiL4(G-q7If(EnPC>QKgqn2Xz zs~aQyD(AygZ{|y#PLEARavNiBh+4k;{Ip78a1N&E*9J@hk*j8Vj?VOj|FJrcv31ry znwjNAhEYHmXR(y=QUo>-UqFn!1fv1*Q^HgA>v8v|d(tVRLZ44QRD42vKk4^magG45VMT$Swa5~HLBQR_A)`Ms!Y3W6;`vBjtX+>TpR|Jwm_ zGfgoV8ewS~+wzM7Cd=3S!v4pIu3pCeW}JA}U?zWB@$x&&HxZ&|kA<%-&`KBiPOS8Y zjaG=Fyk#k_NHB`)R&!Kq9Davz4uwm`;5svgNa&j+w5P+Nlpe~ciw#qKsiY0f6b@BkBOcarI-L6c^0;o`U>w_u;2|%3z4DJKM6oFIJ#QG>lRBRwB zGk?AulXV)?0#ko1a%?!Q$$7@cjT3JH6Smsav{w;j`axq=her2Aj-i9akA(AJ%{$M+ zdjk`;{&3j#Bji|Z;EQs{e;#9mG^cCpp>a4M)5dm7#gXGQ5CI2eS7D2Uu#Oe3Hs9sM zph~!;E*u9clC2sykhTz3%?@V5l`VZ%K9v4^8q+Jm zEQoNze9UVGRyzDnHfMEax3Xf;ywxK%#^b%e5P|PVLn8_?132VA5}=(f$Wl`0oCcI9 zfnul6UG0+1zyt~o8TK;5`;<`ce!!23;h8wZyQhZ#2ocVEVaJ5%E0Rp+lal65nR4dX zt4&nYa%Rc%Otu7hS?qmHh`K65200@34j4y0LA}{*3g3%5b+GiJIPfrZ@jk z%Df~--3~zC6rhJh=s^+s1_|AXll#iW++d(?3(&76l}i#d?!}2!38wiumZEI_F(~@o zto*wapRWP(NfJFh2Go-3zL#oH&Lp%o;g%#A$KPmLiy!E_qD@chT?TqdjK1Q5z6y}~ zdmlf8utoO^Q4Q0m3lu~#3DrO;VWy!6ERFA+Fm3>##y%jXNEmRdCgXkJSd3$I?Ws#G zj)g1$mZ@TF2|VHhl&A1ae0dlOpp2S`G0?~k4pFjk#I`zSwXt%4$6cy&%$&b+#?Yau zoK4;g%K1(dThz=@ubDyZORoous<6+2@)fCU-991p$2H?VO}Ts5*w={zXB<&NXv8jC zm7pWMC?nZepe#`TEW{)Vp^_+7HgFGy0gYp+49D@RqcvXQVy7itB&NzqPoP(?Ty=%x zD?m_HG-hU+_zWPQwg}uBouw!Q$IA-3Lf9*D+Hr8nKMZp&08uKXBC^DYLw-mWpzNw7 z^D0Zw$43jNeY?he`$=@}Mh|XLaJjWm>x2K75#QANzSj-JBhjsxn~%{v11+CmB*n zkEbr!!7_*Ld)8>D*ZLj`$%?!3*D2VFCOf973V0$^R_6 zT#V?KU7Ou+?vmXfNP02~F;}OhWztR50RQsk4s^HJSL*|U6MpBN-|6dgf14I_YFlbTEKHe zQx<&Mz_U~gzx>j6uK?C9h#Gi_#u%b+Y?vH<>EF5Le@l#6=Ax5d`d0GM*Ko)dVM{cm zY}bGZM~IYBZ`jjtmMq?9pdLHPNZB(aQ^ftQnNBL&$FOph%kYZxhben`{yurR&woJN z5`O8-48&OWj1F#~sEpEQ`htw??H z2_q|V@3ByND+S0Y<{Y>yFwf1R_fnDmzNLqrl&pqE37FMmw(*UD@h$z>SV3U1&d3zi zHEGlOy#f#dut76*F#wx`lA-b4qZg}K%AWCtGO2WJ2-_z}?pUB&b9L+SQ}gY1D&r1r z%|fAV(5e10AD(^Le9$?Lx_$GDgSTFs4eK09R=;}vY%k5e-EjT|AErYBsq?D}$kmBV z%&FYs`=q;db^QxC?ryQ_?tGYe0+dPz>WUB^T;=nD2p(%|dKWm%9z+TNsWE|iVz4ah zqJBVV-)hnB=s9S2<~FOTp49w0v`kWSEb>n#w>!P#Am(5@dgvE=Kp6F0P?&Vz_f4;& z^wBrWD@p4c5*9J=l=R2M19^9WNymJOo}1=B7oiW2ne(z*Q*TCfg=VR-z+nhbS~^ht zcS{@qd4|zFmWioZH7cr`$50L)o*$Dc^Q6F#VxTi7dZI*dU$ zYSOtL$k>q{&W872C1bx+=nz!PdFuWjq0J@t;9L?Shq9{vD)WHr>d9$%k7WO45vo&! zx`IOw1fc)tiTY-P@E1ckI0Qx?(E^Ap6`|0dQ5X3uy#n+VKui6mm)(_gfdn4GFMw+bKV#6(0DR`U$}MBCU|Mv=AM9L8aIUq}BOzvZPKwA2K$eq`5F=!T z1TMqMa;S*P0HjyxbRQKN^cqnw#D3y`x<9CSF85kTBlei)HOBjDv>3qqY;BNoH1s-v zN&o_d&<`X|(F41}eSvfFHEk9cgi|na)UL|(^EnsRcuU|1X0>v9pz0ja=&9qC1A+XK zg!)?ceV}~%70fNK&dZ-)OfA~i97t#+zibCOJr`msZW~qTAQ_UnjO43$A?&3{4!y^2 zN3U{)NI8KCwCLm+x!#MPlku%&9bw?Epj@st{4fc|6+iBhpqTM3+~wZYWz-o_X?Q&U zBBdndSE&GpDy+NNHSND*eQ5MAPW@Sim|8h&d8SZ$xz7Q9*CFpIw|f~u{K-wDcUb={ z%tDJUq{iqQ=WYCC*A@P|Bv_=#?vTkO{x2=T5@87Vs~ooDosr)mkBc`R7SQPxxTcVi z#$p^Il?G!lV0)QBp>Wy15Z(T_rjqa(Thf3lavPi$t<&ra(;x&79B(U zmTmAgMa(M9XL6SjB0FF;M~q@q5jkRb7HRSyz~2)AfBv2sJV{}&Cehc<-K>*GeH}q1 zdDETc>Am--k9(8~gs^xa_%IWoMv_G~(1nZ8lA{oBCgiYKHg5~}7Q-2TZd@Jmt;c;Z zO~aU0(knt1p96N~{r(?dJ#pTWEOB(v zI#KmT+}o;BbvqIFG0^{5jry&@e~{+Vtj(j#6DOYTOGxKtxi8UPh$HbjUxKa&A55>Zn64D^*3@PmalAaU<#7*h~laZ&r_>Os$;VO);TV11$dH zl=r><=&6!8W$~^`jylZ%sZsk#MK(Pw2G_cW{Y4&2{LnrW}Wu^%$|-c zD}TG>WF8?jrT!p)kD}(mie1E$hblgW);7)VfyAZZXavx~V>8&N0zj= zs}G{GOyVcIa}IY+=swz`&stDMB2ng(Chm_fQghVn$~Q6$(R zTy2+6hcyvrvVHbYWM_#Z&y%|boCB7ek6PFs&&EL#ke3NoBvo=rQTR>piD zyKt!TOh}5ZClVdp?gl-=$>dP_ zOjAf`ilSgw^I`3{!bj~D$`{^$GcfeI9~SzpxZdWv`BWk?a=A;TP39e{sYmY{fT`Xi zjae*a%GuqFv8fMBjq`cyWt%iZvfua4-MZF+QJZ@GrQ^e+;(GDDV| zNXNpBKnMv&c*Gy%R28N_oRMqhk=UD zZO_*kGi2YiR+azecbQ*Z$K?C!st72Ww$p((Bir&kEGBvOK`#Pp3GM1CL+^Mgfd#t* zmS|t_dFYdQB{FCglqktj?i2D5+r=OiA&A(+%5}3x65l1vZh84GKfrqE%HU*>&Rmlnd;FCFJ4HMB?D%AYJGea1NA`uymbq74p>h|?p_eT?Q2RdGT zu_}NMu?gLPG2r`(Wp?(-!>j4 zPqPPUE(s#_ifPznikrqA2F$8&Q0@re^cH1pPYqU{E2Ms0H~F0M5JuO_1Cj5tgc{Xq3VM8Rb1e0AIg)adDU{lJ>G zu7d;3hjlE}yacs`mRODAj=EerNDZM&w9CpDk&G!|DD?QHGe zbuy9!RjJmn;kZ3`9_4Aaw-XiC^xo$^4dHRV+757>qw*CPXYTk%duU)y*-WvGr?HNy zI92M7^VM7&GSPwEFq=}J9Oa$3(SU7E^9cH;@M=j_i}4EUu@mH}v#ijp@L-4YEk}rs zb`&mwONE8;fy6Ef8=udD;P@hC6s^hHz{%&`iQKBcqr1J+3PLs--|Rctm}|n6)fFAV z>Z0`q)u#pN8!=@NdBp6mv=fc6TmHZ;bMu2QkWHi?V}rmt=HMTYP<3I=gZdf6U_DuU zO&K#EtH0Is}gOo#0%BX17^yL>W_df;KwnB~QF>*@zk-N?7-$Aaz4DPNlF z_FjHrCv`(dN*b0#{Csn!JX(aBUQ|x=Jlm+S_EQLQ`CwRn!j+5mUQkfR9eH&Aj1DSn zIqboP7CxAkc3%Eu>(S>@)pOf{ijB7eCC|jiN6qz@x2kCU^;{C=+7qy)rQhQJ{s)Br zhlZ3U7_6n5L%03NzgzgCc=kOk8K)7+Yy$=2eq^4LYm?t=-wpL`{}bT4?R}lw#Py5@ zQ6vCpyO~R4LVYq>n(IdLn=JL;0D09M6AHwbq3l5t8n;1Aj&Jb}g_!q)J(!pOv|Zk2 z05;|4-26~t+6RsjZT0+D!E;d2YXa*pAYlp+s1Er{C=$TX#A$|S3OT$Ns?x&wF zoYT1Et55)d-5(+Wwfl~$=Gkq)>^$;XiFx~5`;sNFRDemEB)1i0$jH3oRL?a?D=nN; z@LWGV>t?Dh1wq21`Xx2z!Eu(0GSDa~ zyh!B$3zi8u&tk#F!onFL0(T2R2V7znHdu7}?se{=oE@MLl@%;wCY_7d3cKlq1_W{- z7`-oOtZ||?3UxNM()o%**AX%7_MG>`=Dj00tE6cj^pEmh0Dga{U}vbzl*F^H-AdiA z@J5F;*-j4(g(A0KhfYGn_|S{m)uH!!*g|MnC_bQ7qe26p;$ygHO>vJ{?#^JHRKdcK z%i!;?uQQc{jt)U10l4YT)8{~%$$&hW0w}dFSBjHQK|=#1rwQ-d98w?w!qSa^@;$=b z#y>Y!Sp?j?Iu0q1bkKBqzfERR224QzyoD+iv~B)d#vg(v_nqGCQ$UW?-JZF>M*v6- z(Dcb?5$t7@)0~eJ7Zgmm)-p0Ifa6BW*HJ?xAfTqmDud+9U3FW`*D+oJV9NMKOI6M` z8Nv^j7cqnNk5vj_j_=y!`qZdzYjfQecjlu(CH1mA@Z0HC2CpIu8cbC{M)Q0nCN{Ml zi?G2C1$ z+eZrdcK_0D>cqY|=rWo&M^-@^R!`y~VfOg28s7JDXivqCjX8Bs1+AqpBMbM5Ia=99 ztNgz&ivj)T0&vg9;vm&>yoCG;Y=7=6@&$>($T|~Zzup>dp!p+|6aj!^7?30!Tqax| z%$B(TS&HYqkbNvqRf2WE9<>n_PlDZ_-z?KPkc_rqBO!S7ahw1gClZ*V3bxq<*$6@2 zHBfPll0K+GtDWQ62j23BOZo$OV^p>a-{9a`x-%A9d{SvoKt1>bFH%@)TC5P9e|jfh za1NI@{ZH zwzYCA8X@8Rqgt69Te_*AfdkOk!KVh|!^HNrL1IM+mPU}ygPA*_Fzd{-8oLTGnSkIW zzjZ!^#z+f_g|eWh{Id1yR$IG{i`0bP|}JThLC zp<0}lDWmntcI&oOK*-SIT}sjU9VPe;iSh;@r2ZUEqH}I`Ug>`ZnoN?$hU9$0nnCh3 zG?JwWSd_lnPvti+)Xlw@-4Yeg_QR48d9!$U5+yHE3?|yHgDDMR0c-@LcD(JH)cQ z3W9yjEmnf~2kbb~ZB~SsZg!uv@OVwgJaUFzO(xcNo9ovPSmK3?;{FFRQX03>vgQkX>rPRtB8Tmz?PZ z@O!4b#U}%&PxCTD;n95fxzTyMV%Y(c`+i|1;^9iNL2fdC8^7>PG<7vv$cq-=j?8aw z{RRyo@qFk!*BZ!P#{QAoT@kI^J#&y9H4qseVrw77fW>C|^Za7jewl2afWvOAclykC zEiN1%0N14zQDNL~*VoaPj$KH&?kj}ILFMo6L-@9GeOvQIyWlz}o$tpwr*4l(h(>RD z1;C1sis>`w0M5I=6#z0J*Xa7wD+>=_3fY@Lq_Y~?2acyGb+^x)SkD@sp~%UD_N#Hh%j_jbJ@HWA`i&D^pCNzim{Il;mt zajt8>_`@qtcq(guTK?Um62GIYH>rq(&HrF&_IXD$^EOx9Q*rPcyA#qf;i+>lwGX)w zs!u7wrFFM%(EDKCnM!X6T-P-)EtH4LsuZ7U&xcG-+S22Vp5#ui~2zSK( zbJ#PA&@&KdzTAh|RS3C35Hu|Jw|Yz#OU0FY1v0J8w{2EfA(vhhDb&+({JDMT zDC{D~Vr{^_m19q0yNwN4Ai${VYnJ?Lq|KZyI1pJ3B&&kVNo>^`u=aG0>Shjk8t5hk zO>3|%sHM%R;{73a{vgU9{myk&%bj9>{-_B3N-;2-|M}+fRXt zQu_8&a$DKF5GpUE1`;lYZWgsiXWAbX@IoaWJ~gYpv{ajmTK?^?VX8){nl{Cw-90i+seh=gJV<&xW9byh1}gVoJ8YH8bu^er`e|;36MzJ`_;W zrdndrTyh+eNxCE6CimlLBNhnH-h3|R?`vsHV!-@<&F0WLt zK@)QvlrA{^=2O6TXC3xOI^ZbY8A=(lwfzLU8A;wom#;)_aqjfY(G~H zr=F9Y*L#L^a1v3Ij#SU!7dNiwr+p@SCQp~`9(as8;+=Cn|9VH;$Eow{Uv|zA|NUU> zlKeu>do*=#>cZ*k?J4f>J|FSg_Hw3VSKGq+N86gTZ3BBhQB@%*7KHralA1j-qGtBcBrN^nMW70d+9gBI@9dnoi z$EbyKZdFu&1P@Z3KPJo)rY^qEqUKZBgOnhOw3=E|pLfTibzXXhwrhkiQt&s1SYut@ zyEJ6U#nzNj$BQz~C)%aC83SjTwga_cuK2r=>qRst+_kSRqu%>f-3gTIEiGuw`W6jb z=Xy&HLNEQ;bL67y+%XjTX0vM7SFhnZL)^_*fb#I<*QM_J<=^`QwSNDB-0Qt=*L~M| zsdD^dJq_37ja-d$f^6Tt(tcT>&!+Xwu`Tef-xma8fPTK*QH;}&s6c{DlY_+-Pe66=* z()oHtx{2rGoHvo{J2$+JODcMwpO}Dtz0}xcf6e!Vw1^t_2i!X3MT)w5X8~o?Tc_=b zE+LhtK4+||ywK!0EomXg9`b3vmmzth40Abk=4zPJzJDJY%w1&s9<{$zs0?%ctIrZ) zSC8&!ryW`7uCWkEp1_wBNQ@EYcOtz;Y>#(!^wTbsF5N(rW9RQ`%f&Onv|`BoTZ_$O zyy3I{;!Bn_DXny_c~HCXy-84=W{8qv9R@f5{Pcg4-J1s45M1iQb!pugQ|c~eqe%1R zl*e3{OBx-e+b|+R%T-Xi9w@h7tt1C9Mb0_Wu)WhFbO1{~eqlj~VkKE`20#Sd1qTv= zMy7pWi!?O=)+fc4@clsna&Vkp{WKY^aLN(F-6e9A-luIfNJFfmfk?OU71pe7gn#i2 zUVw+t078T>Yf9H7u>ZY}hpx*$D4ma^g%A*BJH!Kc?E?J2L3RvCIcfGDU>}D8u1$>w z;fdQJQHK@iU5n)wIyL4ynewjsiO?8^A(fb;yYHO_-=Jm0ZL-z_#k|FoTTvT3GGjEDw8cK6+aPIjE3R$Pn)Zv;t|{%oUG~t8Ta=z|e+{uDpPCNXzVcg5 z{M8w9JEh)2Rem%6^H)@^fj_QJV8ayiR&NKyr4Qk&$vf#N!l;S@i@WVxOvh41$PAZapGG_{pj`6)Xr|EkL=d z1cWpo=+!>M<)PEYtZvm$zPg;_@Ge6EpycdVE zh#mGrXBbCR5C)4k8r!n8`D`hY(JLH{?5k)(@@h8LA3k)8I%u7gRMYRAV{u!%&b5rb zxv#E#<7;i@HFBpbxBuz+a|M--Q9P$L?UhEuBB*1b#^-X?g?6W|5QwF0sW`#f61Nip zS~!adaY~h<(t|v^CbHV`Bsyf-$Y5b7u2^=%RZ|?1LVR7H_?Ok=?(yIY>5(6IsBW0o zTn|op#;Y%qluv0J#@uh^MJ1#Ro$pHs&)CYYx!MoqzGB_+=JO^8#$HqZ7FWkqKO^M@ zTr_?;9pR-CdTt8{bhtTE>7%>&t!axO!}0W%N|&yr>&Xcx9C*1-F8WMceBX8-Fuc$y zUSkl|H-5f7;lASm8}XZ5%jKu~&M!3Fsd%)ow3NN3IfFz}IR!!G@a1*gtH7nD)zScMM|JFOl)Wj`gdIYU7{=9BQ4m z+-8K0t2~!F@BHG}^089f4q2N7w(JaRPKdv6KWi%6L0#c7k7#TjsK5C&jRoM~R+CTJ zZi_%LRQb4dowu9ctxEs3%jaI))&WC2$8apx^)u0iO0!`e4BH+Anh|vX{oPk!Z89UT z{Q)r-_xiMooVoX0TLayipx1q%ecs&m_n;1pq06+=5hB>(41oB_NAUnqNYcqb`KiNo z2(f=cJ2V?$`l3&{MhU&j+~P7>yxBfkAZpONC!lwI-Kb$JgR zlFPH1?earY4zuF!-+>r4UmV}tETo`?lsY!9LD(r4n)V9I zBgEwmY~13E>Db-M-eAf?%B}{sw(+hwN0F%j5*An{O$s6K?;{8Wz8I)Ka%*?9O~en{ z5ef?J5S4M%tO01;Dm3e5`@IsHoCNB7?rmv69AyHYCNw6A=)JlT<4gr*P7gGGO-^H* zbO}wp#b%$_^#6kYG{x)|k5U9cyqcxY@&F?QShd_>PJkjQ3?cxpZH^h7Vd-D!J^^-6 z@D8bfL^iyedL`TrB+Si57AbJk7C954UKRyord8 z+M}cgcT9vt+9vbpMjdIx@zxV7N+{;s)i#ywdw1hX`a70<0_K-e&j&mtpdadc;$om^ ziR|wka};8)!eZ<|lQWP+12*3j{-@d0yw;?pGO_u6tP8Zo38(~S6KM$`8#D+M53=c3 zrblr+giY-~m9b*5Q?b7rY2e7dayOOIRhJg-CKo|JfNmCJT;(E{LatjJ2o%8rQ&|3Z z!97_7rY!=1XJKz8S}8g5m$t}_Fkg?3fx)38H3LWN&5{ej8+-jtN0U-!%l&B53s0X- z$OppJrWdB#$Zt*0|L1R7LcerW({;Z*;tZW9_us@YN(PL*><`>f=kHot=|V9YP(cU! zD!tnMdF|!v#b7TFXksQdZVcl*SQ%?yZ|<;mJ*6sb0ptvjfumCRcEMsp*3d!OJ@0*Y19)2Yjyq-LqWjsmsuHFA0 z!#-G>MQv96j7AU=*|-sb(TD&YAT)wlym=46>A}2ww&h8n);W!g5W;GJge*2r%R(*+ zaGwA*;-lq~)-qDa;ft*y7IL_jg`JZd$N>01EL1W8M`EMT08kOHV#nn*?cxYgs`h$~ z-V35t=nJssNBxd{MW%m#v(X3%oF@RO#T(OksZ|Q3`Lr+%3Hle@g@2On!N60wLrcAO; zl94TA!UHXh1j#F^D%j+H&3pDVMG{g`^ZNH#Hl+!HD;B=qC+zIse|k(vXj1%M#2GgL zwV?=;93eRqsLeg`9w{^mP!y)?k<4361RF>Va?=JOI!95(AECCh|05qL7vtW?yfmhU zn2acJ5dv_Re2r3o^^SpQ^+R12dXRFFjzTD7^8qWQif@lzg_rRX0|cUhZg zhF)rRxqH(qq80UknQY1L&`(ahPueeR7haRui|HjVl-62KHBs+U_P;X8ah_JXjwsj4 zMXqTxb4roiheKMhz2_YeaIy{)G>Gm{-zFV&1Nyrrt8WMbGszcRZy!GDh8}RLQ4rs*hukA3UoxP9p+n9dfw3<{KCG$eLfx=43B2WDQ_s$Q`E8gaF4&pT<|3HvDa1;y zC>=(W4jzx~q5~bfU>{g_-dzO6qdE2-BFDT1hd1MT7Vtmv_WBm3U7zyH)vd3c9@sTu z^rYari^{>zL1R85nm5Qwm1S)G(yfi1ejjKoCmd&moLCgb{KoZ-hngg|h-zwpBoA7UyVt^bb^AQs}_EtpZ_|={$e9E{7*pq!}?EMLuIgujC8DDBYyB?Y>>^| zfAZYREJ7jMT&|#U_ghrDeVcwzQrKuwH*>P?8_}D6cj6n7!N$@QCi0ZTSM0Rw)BTUs4S%1&noYG%|cbWTZZ`HDEsW6_SFHnA36xg&>H01Sq3Xzd9j{^R#k7KA@(Q(3i!MnXd*u&={l-=NU&4FEf* z6X0g$1_Hp$iZ?jgv3Dze7)qBDUVPXSiN6Rt{%+qtUM;V@W{yRTo{UJ@B}cI{KC&>y z$Bw+P7EQ4&a?a%jCM++QB}P8@fh zf1R|Fw8v!J`9f}P>vE}2ua959&+C;@3H_X`GO+t$QTq|vV>JidcI1cqzWi69ElQ3H z1SI(>vf`jd=<_%xgQa=O z!JiyRR0GYw>JMc;J)+&0Egt)6NwN=f_wSW0a0bKcOgI%|2;IRd44UGYRv|txgfgKl z&ls5#Eu3j^;#3*X4^$$iZ>ZUl#Sj89WZU&HaVRSAC<3FU!x zOE@Wmw1tjnfG1OhkGsD zhf(|K?D&LE_tRK&-;zIziZ7dSSl>tyjn}az>gS5hqxYU#1>Zej>&^euv;Q&QnYmrhQJt4AJHF?|&(mwyz0*vY@2qNmcbr>r z|BiWa(@}VQdh_nnN3%3PJCAO=p?FJ9?pHv`VYY`)9thbe9c&CY2c9H`GOoWN2UBXGGVbP^w zej%Zu{s&q5x}Nogu^(YM%v8c%M;BA7T#6fn)l-Ao6!g+M@j!-URY<-4p)saaeOG^< zz1Vt2l8RgbI{aGhwxBM@)!kuQHMqAhQ_i`L;ZK~l5YfO0a({y}Cg5VGrrIejS_+~R z&#~bq83-ZVyi}Eguqf5iy-xm0F|ZXN!&6bq%Ca5Mu0?7g9N+%ish8sN=+pU0@AJGW zin8rG>~Y3xgRXwL1E_2CkLd=)lV3B^OQx}pwuq02Tk8>4qL80^PYs8RzI}P`3s6jr z=d!C0X#c|AcE@{KR3>gVy&NAj}_QuVwvEhERV<*-| zCOcU=Sp?d#k)vkme}bfo9`&0RM0=a$L*{?w?cG{fZl?F&c&WS36?J8=?OD2Z`K;Ui zM2$mZzd_5Tdfk=tJNo<+wTtj`EtR2(QUmkKrl|{2HW#usgs;EnJ!m>6gL(F8a5fzf z8y9OBYC|R>-J9b;*bEE$!^8ah;jYPJezkim54#yZ;C4O(c6e?c5u7=IFJZx$ts4GD zfIqHEzO;KM9i2F$MEBd5Q$|>bc-nnbh8lWZKu5WQImTTLl`bQX@;+B2#*a4~f8%t= zHjfpxrcT!GJ?oV94i(_&|B$(|NsJY66U_cFAs0q*?= zPHga!;`Xs1rn@uxO7*SSz^XHX2W!DehE%)AAPW{6|2k?(mKXBo`$zr?veqdrfzsi>O zgOx&qgy@(zgSZ{!qs%Y%Oer$1IK}YjG{io^4V6%sVuY^=bc)hIEdQeejEeW>g&jR zQxG4FbACbFY~8T#FS`|@l!MfLG4;i>XX{cpfo@tytTlPS%yM|xr@tG;6Z=yxs)n4E zZwwOpETg9nJYL>)pEl zogWi$Asg??;4Z!N(t)_|NpJYoR`Tt~Q)M-AEd!N4T@dtu9%MRu7`m*1jsD>R`AtkNw-~ZH?1Nq0dT-i*rWeAayT;=3J31qnwuPj{s2`_1= zSmMNfl1C9mIAlYj@6i$A3HUu795~EECnKyrGsN_j)+~85M|$F{Z{!PW6Guek?lVJ9 zMH8s?B!5aHO8~RL1Ht+kkXMczqBs9IY|sHqILBC)Y#t{7yc%K?~3*FGFF?pdOo3aTin!!~$U!dnM(omonJ$ez+F93H4 z;52Mo1q&U^K@GFeplq}biI}1FJ=S1NdB|sxurDI4tOL%GqXaB?1q+_Sf{W$ug@B@H zw<4E;!mw0LF>D5pwGosfBJ?jie#be)ErB;5j2H$ufDCs2L9#>kb%cpZD=f2g9{{`s zmTiR`S69EHvns1&xMOL>3K7kXCBnjlxT0_zKw~r(gHMs-e9ssOtSK2%!weo}Hjf;~ zBgcgpF@Z)=Y{Li*>0~ETJdSJyybn^)4$G0UYWfQLUz93i5*yU(WVL(p z3u{&l_r3@9J`WWDdaDGs;{t+*&@e)1ltVX4V24C!8rMq6TqzmSWm_qwWJ<|{LUO%^ z43byLAfT%dAX*MU0Du=E_F#E9UVs_Z7}rawwF(kfXjG!1<^bhl&5sfK?Gh=ug%OE9 zMHvy+w9dth@~HFS7Bg&%ak|BzBC?i!>%cautVd^@2{huzQpHj`)PCaxsZlZe8$?0U zgAmhrMqyHNHc@B{HzLW$G;2-{t<`-s`!%w^ zN3qOG(bU2-Y1I&%v?eVYLaPR^CpS_wc!9gOu@wJ8f`2UV-gnSz_dzfB27)Oa@-i9{ z7)`Y1fw%F%KgYml3*d=0L;#!k0$}`t7Q{^xZPcQ}Ab9VS1b~$AtPY(NI{u2w9bOaU z3lsVVV8`T`$NW~g0EM6J{c!~Vy` zKFh&zFmQ9J@2Xxxyb8nARql^J1vCbN@I12~c2bvOjMxh=1t6;caDf`Mkp?POgVQ8F zy#hoDclXxn=Ywi^%8BP^vJY7*OuKB4w57uPX_EHW;Hq4Zie<>geI|&Jud&Jeg7fAHUzCP8ZZn}7@k9s zqS&N`#v2RWeHl`chtzOUV(3SYNB%=X7#VsBv9sLxNU32v%TR=k$HbQX;AVRYai3V( zY1=>QS_SwPz?L>SdxnOrtEV4PqrXUSXSNv?X>Me&jf zo2`^;$^+WUfn+(5DgiFt1_a3g5Ivxp1}M!0LTCmfbmI{{Vv$WLk(&D`j37YEY&xkz zNXZ!_#nZ|08j^gqafFm?ZK=m?{v7Esp3&%mZG)|IKPCgHB|^I_UJX}cR3aoM^XyFa z|2xuSyOX_PNWb*rlkOtc+@Y&uWFdhp zY_=A9`8W73$Ae!Aa#wk59DZZ6RnvUyJB|rW&WH-yipSgX@GjSq`MiyK+8;#IYzxG1 zeBZoigNFeh1Z~B=kYEPom@i!DO*I-M#Y}6*ohr}~jZ>ps)X5t3q-5iWWJitt@g)1n z5&8D}EL_4}++%^S2@l(+cXKT)w~?CV*eq#bqk1%^J#}V})u0<1(=LT+pTg8yYHq1{ zFao$%aR=@p2M2KBCowQvLR*Oh{^c`iC(FJ#gJ`C~aTG6x1#qqe)tU+)tv8`vPrKIRx+9++=ElK1dT?!*wxp{2smiqE2=s9#!#<}@8yH;@e znbFPjDF_F6B)dFv3k^0e!7fU$9{^c;WVTk09re#{VnIE)P^KD@&$`0Nzak%nGS}8_ z5kQxgLq!s7GHZ>`b|hc^zD0fEm!_ui8+t0 zSuD$xb+L0r0dv2&bG|^5FPq{aoddBAQ-EgqY)S=>%%?jGPX@*FQrAhzahe-xY-8Uu zk>3M8@D<_ zNG@e>*u^$R?{D0pAT=dY`hnCUx^bM5nYYwv5NM{y0mC%+^u?Uv_pGyAxX2h&|0vgl zP8cCeg@(k8?^qiZ==o#MEst_0$hmvI=WaLje$Yag_T=9y^0na+us}k|Sv}f$t2q#O zyUw?J4RL&^K7(wWpfHL68Z&_Hxi^RbUWLyCJ^W93e4R5&mL^kqc&?PN(z?%+^I&!i z^DhfMCRq7ajq8s_MW6LaaYpqDF!lCWg&fhxMTn}$_(Df@Uz=X7;Mmra=N<3Mjit1C0UDBa+s08G7}+1LYe@ z;i0Ev5*tzpVJSOWwq-%;;yibKeLb1x5p9W^6HKZF*p%~FStB-{ z1x;W<+y45?-3k`)@O`oXI`bJ`D}j1X8E=pv4hRr4r!Oa|_nLv3E(4503A9$S<`6d# zst9`9Z5SuL(PU&ueu(@aM_!CVzWB>>*LMBphi|9h=eN?T#|<+yYyXQPiF8wa5PH^c zT2_{#FMDmmzI%OP(i^QGNpHD8h5*9y0R`-AHUi~lGO5nkbj=?bPjTjHWEpS6@1^%H zojJ?{GOwGoXVEhXfP-=&b^LIn(GSxqAXSDVW$u2X{UklN;CrsmFZ}$~obv~zSKK{q_+U3a_|B=xL%n)1H^7X& zf2+cDeB1th-s5FZ^J8f91xM0?;@@NX*xeGqze$HP>p}Ho@397@IvzTRxxt!Z^fg(h07qBv3VmB}$Gpd;geb zNs5g*Lx6cEG<(czD54RcNzJYTLHAk2hjg?4tL7>JrcVR7yTim%gX^~kyU4*MG*A&2 zl_uZPLo2xpxPPX^2Q-QpQ=9(LQ|tj|V^Ts1J9|-%SjxqTy>X8_P{?JZv(9wsh%vd^GSl)msD+TqMg^K_$c444B1!t^mP%xdZxuv z68+mgnFWqj1Cr$ZTWDydpkqeDi0VWXvUJcm2~-4tPj>b%L*2Z$9F{GC4w}81+`LrPN0L1{0*zBG_b|5Oy% zhh7JQXdp<7y&(%F;u(3fm&TNmy@f_J<@J1p(XD0Vg>F*cenW2!DS}5lw#5*l`5GoP z$`t;IJi|P7b6m}(bnBXlp3RMx>;t_K4^M2qcg!U|!$X?3M^j_Em-Ahba#xcE{85`7K z>HP7>J~OYJ7rl4iJifpF+V#Jm?fzYbVZPWp4I21;&)MbT#y*Bd2}QTExVQU(m&Xjq zdO4t-j)=HNa!R@y^vK|I7}+uE(#o}_9`Bn^th+|VCSKM}HT+|$75eY|zDl(K!- zyDu+KDn*GUbm=FMPgW%2b z5rY)}x)2Bw}LujVkWJgO9HweH%Ewi9`q>uEJ}SxFYfd2kSJe7M$+7EPpoc7x@O#1>ArJCL?f<>lW8iQKk%DqQuOf!q9qF2>&Ag_bG4Q@2 zd$`x7e^6=QcmU{*b-GGQ4bK`BTVdV$#cR+`h?+zuArKGxX&8?^fpv0Y0S6G{iIo0e>DcN;V*`=db8Mfm)n7(vL)u*sI3?qLgY zSKf$Kbx3~5tNniSQ|`sv2i99{Y1}?yzNdJ5T;kV#Z>nPUhrH=C+hJu>W0m-lqnq}c zvCfEE{#Nq(=ey5>daUnTak)NhTUtKuy8l0AczQG2cHH|FrM$-t^b>?)-dg1BO_M}& z{JwRD1pT2kEwDqZ_CP=gs}&WN2&?+S6OqfmW-h6Lf0jQQFSveXAaeOL|C!+BPraW+ zS!d)N954Eq=0CLbvr_@a>-i_3k!9|;|1L0^XplI;^Ic;s*wRcP+)oQcr)h<`f28Ow ztYB-Fwt5wshb^R4Sj4lc%$qd8WVQdnCXWj5L1_S+Qch?|gSbT6hZHbH7JZNk{Ia{j z9^VN2icebk0C-q`RU_&S3y!s|DfjAVq|V{tR%uzaiAkJub%G&7pbi-5cvnFbt3cFV zZe2>Rz#xeaaG&7qD~uLl#S#Xomxd@F6-qs}pIuUM4z?Qy%;RU|4mU)JnEOPHvN|D!aXl|g*8JY0Sz!~pO{ zWG&|*qY^m81I!9bYazn*r;s2EfOyxPLu^6_p$xSE^M4AgX@UFEvjT(%7ifS0fK2g% z0-?3h;Ei?ti!#XS!VcUz@M-f8@@iYA(C}C^ryNw|y33~v*UKGpOcOrPmxGOtCJZ@+ z@oHBuq>yhis=VS`8lLU;r{5MNx(x~&QmBJE<7FQMKNSlnCbS~5db*NX9D}MGv9!9< zFvKi!+smt2V|a`*#Q3G}(-bQWFUZz8LM9+S-cnrT+@ReefXPPB7?v|@tN;`G__qRi z9WZDyD3SkPLU=24z|=qk3u<5!zy1t+ZJoFGWXOztNY-BK@%d^_mj<^#Gc5k0W&FwL zC}QOGA?L1dbzv{Am_3le-u6l{H%d>KRa~iBrnhK?^<5#B-4C*+G`57v?ql{a!88|? zI8@HLeDkCE7i+7w`)#6P5aE6Xqgx-09r4WQgL=MdtPh@bHGf}RZ91*p&u`skE|CW= zjy=H!>uuHJiFbUH!;^m`T9I@+8uurEPuhF}Dhm7F6ITMcmH@rx$4zM(bGO5YT#U_=SP}9CU^fd5 zoHfEGFQe)$J;Y(RaUm`E3W!8u+Z1v|ST3Fs-OXzBkxX^~-iu<`OQV||K-3ScV5>1? zBRgR!8EAipRr54vS7FnERJF*c)M3r4zUJ2L(RRyiMiHT^Z^WE4)hr4-yhJrnzNj8u z{az0B(wz|+B}9W)PjO+}h$0+q2xL?o26o1aNZ$I>C!Ynhrm#uF|3p9;+zr!c7C`ln zC@LZx>;9(P?o$JC*PtgAN`O@YPzL;jY#pTTeyQ-l>!X#F2IO zna@zsDoe`HVx?=L5Su@JQK}3XaSeEj5!Vk;uKgbJK8K0rby@0#UGJRI-1f!~rjQqO z0d`?Aml88)j<-_o%0L?A;gS^m@PW#6%etLbFYxxURQ~3}!-!R_Y!nVQNKI0RVa*D3 z#`J*654?UB3q=1nA)+wCVV$~J#d#L&!GH2ZwyL~l{~0vV zOGGJ>fsJF;Xdjk{GKRO7neYNOYQMs2GCK_IH3$}^!eWtmiuH;9VC8`h0)H?pxNDxg z3qV;?#~AXN6Ap!^*_am(o@Y<;(96&uQmeco4yr(~GV^pty=oxK^YZU(Oa|e9u)>QP zzzhkn$up><0~E%$yGDZWRVha%GcL@eHpW8S0FdMJ;7zR53!Kzt@eqEML!dm|D8W7y z4{3equ=Q zd^{iEEd!b8LmLes2G`aD548e-%$Oi6b%eD90LBB%Xn^M;fB`G=-Xy?)3vgCJS%Qk1 zx3+ecq5u}Ww->Srum7UqoA3w;(^(`lSoH={s&*=s?B2~ZC{;sd(hOn?3_4}V!B|u$ z-{3qKbxOAN4!~V4K)qrT)HEIT3LXWfqscsjemweWH0oZ40pp)7r)9_sYNUd1Ac=OH zu|^%2p*rx}jxkX*puqrs*|I5v9$SMByg}@1k>ij+Y#63)el>UR#sZf_>o}=GQgWn+fWX5{)dJ!A6UyXgjTK|-ZnAWcU+JL##fc+S~ z@~aFZ&^2H_$}k@#D}`LleO~f=IpPw2Wnnk=!zt`j73Mhr^PG=+3_#w~A{w;lpc<<2 zPV85Gx>bO=$VY9JK&3J;SqAr$fK&8t1pF?477P;r38sM~m0&v#By6cos)Xk+aM*Oh zCcJ^ZJPk6x3@&2sMk~Y0er(!39OhFIxP6YcMik}C*;OLp-BfN?EsYZ(_V@sxKjF}5 zbs$>~iPrClgP|*0z!9ol;eb8%Qeu(vNLmGm!4I$C*f~}>ILbsT$&4)^hA-1DmG5}} zFgVqEvokH!MayUbG3=%pD~CZNcO6_U9ScuF!ti>Hk)S{V2^Xc8b%yEC9C(?nO&JZb zPgUV80P%=mdo8$h0UW}GB&)zFD!p<9NCB{oBtR1Y=-AatGdXZhnFVYBpa4~1nG1pi zKnPs)ogrr?m@yBn@&Ip=L76JByBwUz1Ut|m%c5bX2*xJ0Ae&#-O@w?*++t5ybDy!S zO(m|MMxNywC}b!#Gj;)HfOv$wjz?dnq1AH2c>(6uV}t7|R1Xb(oQ9hJvD$oM>%GUw z-_{21YhCFI_y7|%)5JtxW?tuSn_ueG(SRC47);|)f8tOV@SPW!h#Q`kH{Q|$1mLn^ zh)Bz2E`yhA;TN=yF=oLGLu;va|?o!>W&3bg-K0_tbn>^ve8h?25(Wqu;T5Snzv1_NBV%F%$7j zhV=_I-^0ax<{Hly!1py^Bf9nf?-3mmL<3iLfVnnc;?i}3!RZpzSH52Khq<=h^j7QRzQM(mb_6=;#+fT)+^!+Sby!X)rAF68HvUBA5S!qtIVyCYYS z_T5=_&N4`_yf-`=zDux-1%;%Lfk1<4pANXn2Rbw~Q7=qZ_Ea5l3l=ORQ~2xqO|6M{l&j5$#^xLZ=-NE0GiQ(XP_-;iGu)7%k_+g#CZS2d0xM};=EVtRkg|`6=j?0`b z0I=QBaiKQ7oO@@Q`Q6SwNS8Zn?XJQ1XybC3dQc)bjStJ!?mA(LG-)cVZ#dWFPughX zBP&5v>fzP+^BX_GWMN*FH1f_9Lv_(!yViKEI^>nDiQio5wS{&ku#p`2gp{6$d4IV@ z7xNr%@W>3gUxwI~3M-Hy4lv<78Y~3>ag=~NW$HTw2J&R}1SlU`o@4w7H9z0or9yUdokxayhBMav zB}aDSQJqZJAC`#oX^?0>qzD9MFdrzNLXwwWz0ySdRtHJi3rc$qIo}JpK_1^FKk~KO zxkwE$a6WPd(rzz>$lgM>OUiqyq8@>w928r&1NOYHv1A-LX4?ujSe%G&a@Y}6u0DAz z>3nL@Qiq}g$9zY6fw^#0&_t1(4V3`bM*}S+U`11F>eb1 zmPP}`sj`mR0x?Vwi3Ng5rnuA6xfz#-kSe?^0?!1lr%e5w4$Ec1HqgL{G>EGV{H|Gc!b%VvkJu-OMj_Oe(I~$^yP-wVd&1(I0u?7+) zhy2$bWj_I_j?y7I@ytdBg8aREiuhjHa`nP>c%%#>^Hdcw#oHz6J2N5MnIGO~1=@ye z4wZ+6OUk3QeP`0zd^~Noi!FD1ewePlDvA#Gn`-MhOpH)xA6WURi0Pcu0f(zurdj}j z1&pJCU+o1^0JLwXpz(pQzpp^oH~h7P1;Q|=UR;4%;g^`Gfp2wyqn&`aWdH(gi76NG z2n&QY^qX@Zzs&`@0l*GgkedXUD6?K;1KA)0Ge^NmT9BU>>fWGF@;~0d&E(->>(yWm zQ~#WX)^&ncG9er_J;xku`8Sl;00FNCdqzWEC&G%E24T?<;8oilEC??e!p6XBbYzJO z9=uFYk$N(G6YF^^r-Ofz;UWO-p8x6^os~Td@@EDAvu>hDX1PfY-XI52@L7OMphSuO zeh2KQ1x5S?XQ`oV?XrU=o4lngfM^%&#$a+Z>vircgEYz9idW0!(3R?clU0V=Q_vN1 zs10oq*~H3L6>L?-XIw`|(Tz&iAc{{E)(;G~9QQhM`)7iUPii+T5lD_DARptApLqH^ zGfWy@_FRG45{>+-_WCHn)Zv@)8!v6aUwUeYOpCrao=*L&y6}vNyb*n=@kMcjz43b) zVjs6j(~SurAXT#0jnRl*e0Z(eC^ruwQEhJjXu0AX+_r>_-vkpObfw z*#vzV%PJTQr!rD|0q-^GXj+(Hd{JeY0zYigcUJa(q6&3Lhl&gM zDS-x7e#ylEo(PVnu}0MN@?D2a;xh^x-K?tqA4O*#m1O$aLvpT za9=VrGc+?$GfOiyYpiME1yo!zw9U*4%*x72&C1FeKr}5gG%G7B&=zZIMyIB+rt$mZ zcMgYh_>Tkkd7tOGpXMw8eTSGbE8oW= zj#WUvP*&wXHs|_A8IBXbQJS?3Nu+ppQ&L?Vf5{uWXtubS7?HT_@BK*^E5b+%=Y(fS zxlSpZy6xjT-W#?HOl8ONn9=hZ^SYw%Hu&$k>f(^k-qGxkziTb3F6z~@cgJ!^D&F9? zsY(OtV)QHCKmDbn5W&d5WIA(vvsu%6{FmoD${l>to*^4@4yTk6(-<4y)kYrP*o@jd zcr^Uf;cq!_51bEXoolXxnm9_{l%dBnNUGpeEdCfJ#RW=dO_PDfmLA;x_lnS(MaaKo z86n?V>h~Y{M--h=`cugvEP7)&^b#QFjOVOcQ0w<0Cp zbk1kliINWm%Nt$ScN_^KBJ)qjKMOk$Pug<6dC8BA42xylnBc#}PZ$x@Br3+`(6s@J zvU*Vh>kL(2)|6(wG`kj2$#;1MQ`Wzj{;%3wKdq}`4G}VNg8wl?vEGAGq1B5AU17?X zwAy`E(LWVgc+ZCg`^p&|{Q@LGBoSBH4gMqeY4_xqnh$I&2??XlPZA?+-5_SpI3qzF z>0YeG2--lan>-j>2%`Mj05q>zYR@g;hmT2rQd?mEjj52n20pyUSLTQYf(5Q~(n>jj zA1FVI^g<7pC5dEIJvWq{0V-mIvIFy3xWzVRBmx4no!%eGO9T-v$`C8P2#Rad>`f9a zWnlSeo;ms5hbhJKqWAI$J$T)9nTGY!@ zd(MQ;;;y6(wIKx;#qs+a9y#7%Mnv8Kh`FLSjs({E&?q*sw))Lkvi4k5RwC)#h!e*r z`a;|#KIutGnSU4Ya(+Q3EhQ&x@nLmqq-#~m(9ZIJYcuYfQZ5KzDB*r81U}mnx_X?8 z>TDgZ**0F}@ae~@pf>HkZCO3Ur{>kR_)yJiRVHn?t!&PpIfLE8mD$ymggG&1FnOIE zy3RGyV|#m*5N_^bGUdz!7BOBgEh4Ff^OkBYmU^ghJI7cK@ZkzxiH;Q=rK6jBtw9D5 z7oqAQO>mVpyD~CtGNYKt8AHY9QEQM;su$QR4$Nq+ZBK{N6$JRAH^xodgwjyAdvdyw ziE`DJN zCQ>78wNqi-GM1TY3f*>oN%9iQRplMTXESWd_J76fUI?p+bm1z`8QfzrC#g5Ml9;rAB6IZ^Hv5uZ$i zY^Kcaxe?_f|ArsA1+%+nM80SMG5M}C%665*pGgoT)f@@hVNO*UG{PL)G?``?g%0V2va3bagBNo{Pp_{2^DMmPrjKu*{bpphJyd+&t3D?#vEdhi*B z5Gx@QX{X|q?l#J7mH1`vhBZWKvyAnX8N1k^BXmN_IA0;C6t4B*=kQ4HujW~as7D*F zs>&Cg<70l+WYRk4oNh1jULL!Z-Qx(~nC0KL-Rglm%*>S_{mrh`b=l1{FT}v0&tHVr z488qaTXgi`$5rjAGXD;7Ob*YNmdZVR{DwKUz4O$u-`pNAMn14wn8R87#DMI0)oT^} zD@l0%=M|@}nZbhANxSQ+uK|O=z7TUK@Q?n;CD%;oe8T-9Wec2(W5s6&H<5lbtJ|(= zae4I<m#2}_J>;%=49aZd#5`QI!8s;bwG zr>o!>WLRJA?{&+bS-!mZUuzPvH|pfFr$_!?Z{zwB6XvL+gc!8AovKiW3Fhz(J)hvE z`LKL<<=jf?@&{*q{{E0~Wd!aLmgcfc6zzEUv+YtIOI1HbyL-Ec6jvg>bO^kVaxTcG z_en;6!J*dGGVklRU7lRryqg|39#kbBWq0TlrK|0%uXkCNa<|^@npeiMJc8mBtaj-0 zo?>CehMxVKsTXMwQ(3^SXM=ko@?(E;%XT8oVhfRcy073HW0wCp0-~cnFUatm<)O@Q z`(`=s9~aLfH*K%+O5K$qn6GR1&iTfpmrB%3WIOa;DNyY;$oUfWRKU&Q?>?b0S|%~4 z$psRanX`J}>plJWBCj-^*+gak!u#U=@PA|*89V6?(xx*1U)SPS}d8{i%?3)N;+;0MnUegI{z%uO=W_NBAY)mwux@&`iEv1bdQk(KiPcIqbR6-Z!Al zC!(bzL2hUhs_1&%dFo;+Vx9&j zosiuqK?g_KA3b{XxEj`=zQ&`L^3g5TEtS%$F7r2Kj?FJ!8K`au=W|P!soa-O!|#2N zYjkp#f(mD2Z(4PrI>WLTJhqa0AxU*UD-3bq*JElXqHD&jE2dSHQ2NyFc&5VRAm`?_ z%GMOo@iUm?w{N61?R1jBgMmIkyB+G(DlMhBo6}LC_&~VP$CN8d6LqM zX{3-FDwLG^{VdIxLYVz>LIon#P+G0Bh%~|)yGn&ddAg9chG|&aQo7P$NMx2K5|Fiv zT@!LF%$~lAqY9_4hMhEch26Oe#PD+Cygjh4z!q4iq~3W}o(w@$b1PiXRS$?&&QVf& zV3)mizD<3DY7=xBb>Y3?d0n*;r7VErkD=(HJqv43gLyBsRucz@!9j z40uhSwE7Wa1-*kj=Y?=zV&5f&UIFH}0H(Pg1ldafUpI80)Lvrm6NHIXLT*oi-K}MT5Q3JjW^mgQe)3G(TM%(9DflwdI z-{iJXE(GR~Fj#?zyO(xi`a;M8E{u)wVMPhOl}$d%#h_1-_hrsid zEPGPa?+bx1`OD$_R(R9C(EycZjZz*{f?8En9vKDurRz#8?R0^wb5tgr)_f>eb@=@o zq*z*>tB~jQBgH;va>vRd+x)Uk$j#id*{PIs4r|sI9$DFOW}Bp}z;$^<6zbd0vh^B) zN3eWL2{OCwVcuAouR*q}tCZJOx{iu6s|vTyA~Lv0Aq1IOg2?2SuA4%nYf3YsLeo{H z>t+#Ve_>|HaYJ}%M|-J|iMXV|2$V8@m#lsmwqzCtDrK9F&B2x`TTXsIx)Lkp%`RNh z-*~73)==BB)Btlfz?L1PzWfY3|E%$(CDb(m>H@(musAOqZ|pfZ|5$rhH3V_9!AwR~ zXYIPNKDO^jiS&Ttb@f;{Pfkl2E8XUVe9~Tee&V&r6&Vm-s%Il=PPO(uPdIhJBZZ1A znsVFReZFA4@3CE_SXJKFc0Ln=3OJ1_3M?x^BL%|J)l~TNZiJUU(w7R0NI-kp$oXTk z#Ui*bQ+gwI33v!kHApBe!O1U8wqxK|@BJQa;K!XRXf5F0BKdqu!Z14kAhjb^wX;IN zQ30;n?Py2!f@FeUryXDA0vK1W9fU}&zimI42>v4mVJ&J>A(42l+uYYQ9+-sQ1Tf4A zn;a!gs;0I9R<-+=MsQ7-Znhg2V0RG&`5e$GCK4}?o8b^^w^12I?RIu8rUE z8;~wZxxeVUl^3su6eydplTFj0%xTcW05IG+o;+xI+}lMf2IUAD0}{uA4RWGV9E2&3 zXojzn?Gz|wQDIKpx4yg+-eW18K|_I=hbC45P`V>HJxFylA+*tAWXdvKK+mM@v^HNk1MscQPi?nR;F! zvs_vtz3*yKy9Qx>+9CLK^}Q9A&4%i0BMJA^ZhcV;J9MvwImZ|GE*c}28E*8okG$^P z6qGx9qjOPNN+2>L(BsUN3*zb5mq}%ruE>mp<2osAO~KurZ{!t!4F=A_QoaGq5^I*^ zoY^kGrq02c?Pi0A$j&o-exb!KyX;pRsZ(idHV;> ztA_P?4#3P${+G0hDd1X*K{8QFP9XiUbOG-B$?sze%$A&cGBw-S#%^I{ZEoXwHT^ty z$(Vc3##L%OfJcAKL|Wd5VU}j~PR z-#1Ow`9}+V$D~7~Qz0Kgf5-`j`FcCI|9%*hONyi2%6=!d^3uh54*%xEmFa9?Ec6M6 zVY4F@>dlq1cu)r}^yCuhRbT0(1AGZV&ST1axzbUdlJ8W0|btM9g*xFW>(54LuJd+=jv(H8Q@2*3|QRSNxlD1~VNB!>d z%Z<}lmZ9P|cIF3;$9{aD(NPvDMEx>Nb}N3ktHM8G3h5z7F36MTCs6W5<=dvttm#6S z#ngEORGtA*Fk8C31i|BdT%v>t447sw9_frc`XW365fHOm;Ig_k}EEO2Kal-27-od=4UIN}f0?PaH!eOX@3bHMF`0o*zEfv#VFA2R^-g zrMhRY_&$7SslGArb ziyci12&|b<>lqN_)Psa6Nkcc+!FA_!j_Rw%t8 zvIRnwi&o-@{uO_})=QzZr0vJu+c`T0W%mDIOO{}UR{OCCn> zFFww10OH4oUL9TfhG3Qx0*?LWl~3wHh59|41Ik8$rwb0SN={Ap05g~a>n_DHm<&3B z;r;RLA^*1=^P`BOp&o-dDLdrLfXZnr+tptd!gDb{Y>GwjsGRb)mt}cMxWsTRKgB7()@gk%W&X%D*@v!_ z%q|7-dSz`ZKV$UbaenSYs%KB1=bG?^PcEgzZd{glI{%5APG)0~erx-4qqr04`NB76 zKYQD+dMwE=N131UqR%t!R&I6c%bPC#&WSls1f^{1K&=x`GT1m1s(8rY$sqr+!uo$t z;e(LR>aX{Rr}^7o<0c+$P^$QITmJd<;nqCmNuK=c>~~lRZrz<-$hQ+VQG}B`upgPO zL&e}+1B-SI5%u9Xj%(n1UY#tu#(7bZGQUG#lf8IX@XWyH=V-?VyBf#D{{<-uT%sGO z>vVbxYnRwGzb%&rtX-BK=Xz`3I2}0fp@F+GX#KncXPuw>TUkH8eZXPfiC_y_K()|w z)m1r><`P@)Cf1hKZjJHm8Qp+!I}blq>DGhp(NQB6W^1XlW?G8R`r1}XPMKDgB9|H= zGJ-c@MoRIXQikIf7qrW8o^6s0gq?mGYUSO=E4B6MX2P*#J){Tgy=S-=n~GM4V+FHq zJw$PqIusxCksCUP4b~IjHW6KYFl)XCUyr@Kn~1`DbWg)-+*$~n>R3G}#n>h(rAY72 zwm0|?^zb}l%qI(1VvwE?YUMYk(qVk`%3iD+8i=%VZ%zQL+>$gd$By1nfmZI_)4(zR zbW52PRdgOa6x==SiuImVnc2?n$}mOeB~1@*2~%5Ymr+m)V3)R<2i|!?S2)47RnkKU z`O1{z0^{8E1VIVtgy+q2^;e&5;v?-+zw#m|PxlP>(5d#UNV^TLg%NfcW}24qE1{`^ z@$KVGIX!}2d(5$@TCAsM{>;_eMZ$@CVg$BbPX+x*QS_irt-^5+x9CC8pLsoVV8CbL z-|5_)A~war1R(1(Fd%R;XgJbD~7?y&!xKgszL=T{#5Wso%(OtSQ?8;3sdx4pPE z_HOO?UzW~?Eqi2vE=Hhc4rkvcb$UT%7@`?tn<#>-{U+~L|Sb5_yZIq>%1 zizGg=^Uh+etN%N$KN~&I;zW~i?l?A+e7+vxmI$guHm4BjROsOtvu(84gb^8)c7r0X z1z$yL)+$2j^664%g;tnmJcJV&WM17o#4ag_OEO5A)mnT;L64nc8n)OKM1-KBxH@y4 zCFCN*{}Pbj5o9a^hzLTjy0WLO{B^JzAI}^jH5%p4Exg!}96qsfTISU$LU;&)ydssH z?xaRWEBKa)(=F5%5z>EZCGMLaJFw3f;nv1SWQ%0(1H35j7M^-63y2jP5v8dkb%&z> zhc*Z@*a%WYZBn~I9?G%ATxGm6N{I;LGp$91L8No}>%ldTg4q9$EBbyFq0@Of+%5@x z^{?E_~+)J~o02z5T{K zf8A)pezxynMDZ~zl_#xEDfjPaw>oyB+U-8oK1!Kbdq~J}c+R~W)XDGcT+g9xxYFmtFpWk!u}Q)k zeF4|nYf1|y?Ntgzk~{0Mc8>p+T!P6R*Z;sa%bQ3UOB9@+TvP}|fXki^rQOq^c%M$& z9xModpC&%Cx?5*^v#DUoo@FozifM#{LX6&E)ZOUjmM=z&v72*x=rvqq?miF=nFJP?>E?M<)($B_ z|Mf6a@oeKZS)dw_9 ztl;B6D&>K$cX3A#)sqeqP@yx?2=D$ZY*xZ+_UYIG{s@S>34xO#>d>1b+KQ3{f!i9> zs$c>Xe={N5>vst%YPy4b&0p@F#ZOHpL0!acaSKYsh+L!quVBKR#z#=jvphxDRjEab z3T~Fb9WRi`Y^RJUAwy>YPM0!w-5=RK;&aq<--ZSMM57C-J@i2%f>X`MBA4pc{L?1) ztu|p#PWHIR@hg`GX5ce&pnQv&p^k zIfze=oa3;5n(D_M9u7&8Gvf>4k%zMP`3p}u5YNr;dRn+IrkFM!btT4WPIIMJSnB>l z(-+XcU3t)Ho3pYc;&mtJ8=Yp(Nso$LoYt3g`mZcwtpx6513b>>W!mA7ALFn=EBVef zZcnL=sl7pg2M%m2II`w@vgLI`?|#oPH@De0X^UUqK#&cQX(dD&0xo;#^=+X~`I zQIp$k4qij|S}(&HdGG~I%sieAijkE6~V8a(oo_^c+p7xfCTq7 zx|3-^DB;_w8rD~9sliG*R_YR{6B3{*&&cik^n5>lu4I6?D;hU?fA34J16&8+$>$XG za9Tryy7fT?z{9lCn-75Cf+VNr&_RDhqOixYnYJ<|)a*W7+zj(elHQll!qVUzh#(xh z?&t>~X?4hpF>S zXzbez^Yhciem&+NC=ecpdxiB5Sk>IW+IausWA8gU0S}hk)np9Qdj<+xlM4z<8;A)}~LZX9zq#d_ge_Z!zdsg1s#|msRe>n*!@|pbx@Fr z2UT=}_L^%gC2R!MDk1cX@DQ-Gm3J&hi#1r5w<(1>-Z624$9mB3g9bNsg#4+}JQAo2 zV^JYzGUv9dy8085#|bB1YN-i{M|be4KOu)qDYgy#lpg?o79_sm9on5rMFZ5lRKj)< zz4+zceOkhJ6+MpXm_jJ?~Lp9+o%`E^Y-V zU7^1f_5@S|Ve3PkJ7p6EVPSbaj(!jPn;-l9fO`QhWjP+I-5SU;V8l+HIy;M^h$-n4U`*coLor--h}>4_~V(&Ad2i*9oNR>GAh)Dl705Sau2k2y}j)9d^qu z-INC~(+}irDkE}1Vi)ffVt?bV=vsfL{O=AMM9T@&$zgXQrU~UaS;@xl3G}h!g~P{f zk$^)#`G^Ofhz3J8?gx%O;C0?$N4~X+4PyJ5<~}L6T~Z!BhansU?9hCNU3|M~VKEeB%t|QJ z5@f4}z|>OPc~k}8PE_IlA@`#d-_A9=EwRj&Wb;)4 zd#A>y3l@{M3SSR%cqIrL2OiFSr6tyb{ziRxBX(ijvIQO0x5DKsoi4?fJa!E32@MO2 zh>|{`%2xxrP`=)ur*l3ZHrzHS;9c@pfSi`8`2o_^OJFfsJwAR-kAo-KCFZ5?kSu7U z((;^t3-Br})GMJ>xa08`#;Mg857joEtBnzG(&C+q&FQWJP8ifRZ74Zy%dGte@5(KM zV5wdABvF2VZp>N>=tdqR9)j_!JzhFC_S!X<+6g&U^UGR~o5@n1MM0H?juU)mga{Wb zu6?BW8sR{|*VkvC#obl^P!OLy)ruTRmdAG4+SG0*V5A5+dQ1=xQ|723u2>M^u zUiRRF?Tkea-Vr+iLN!034%~ATxY7XF}TdDQb*;af>AqDsM&7TOb^27>;RAfC(( zpOi-Tw3I<^X)15`Ww&c`$1n~p?yp8bdcvy2TFAm~SyHp?b3s^AagW37glBQ6{~zyf ze#VB;oBY4B9Z33^G+B_J-n;>cDu5+N>C8&j2TZSTb^2xDVXn!`IKeq4Ug+Wkon4!L z$fJktF~lao{c>b+PKd|69yWG}laRK8t8@2z5CGpeEQjh&(_?Zzq(*N!_-$E&BgfbN z)W9}+GIR3Y=062Cf4|m0c&kZU?lnl7Y#vG(I(DKo$gS*+QtK$pqz$g5nO&h>Pf3pV zb)p%`!1Kgw)&y5Ayr|+fPkR`BESvt3fo?{XbL+Wa-WX8ide^SKtZ}Y8qU2kIb%BN1 zgTyEn%xuXe3b@L7i(gP2B^y}t9__+UpRLFfoef_5U4b)ct z;P5HR^cKpoddBph#NbbrV~u2m@s?6W?KI!o@o`r!)5f(|wlVYnr?om^R#2D&@A-Bu zp8_g?GX3_!iy(RqwCm!p{>jg~`)&Py0B3rEvwUG{zx|I*z#LS-i4x>A0JJWkc;$NJ zzCzd>g^v*YVN=G~#mDqpR`Tl183!jtRUJPybO)+zg)y-3_))y4yqu1>|+D@?n zJdRi((k&Pcv$%v!D@A7r*`iaUv-Y*b9`-l&G1<-G)FlDm*byTlo zaZmrgw1i(^0nIxTLwlp%?(>mCoo(XpdTu+MUei#v<#OhNDdmDg*&9<EK8&TmjE{leRgWNZ z*A6evjbH{Iw%a9emkbfRwB{wiPAMwbs@)TwOY7p(6d>ZL){)LfZ2NGz^{4yHmcZIw zYkTQws#01PxH_&gws8KD?f*Pnt^8&2lj5pao5`Py;1*gGjx(qnI`)sf=1>N8#=R7g z+6vM>X(gLy27Y~+mv454%!NI(b!hW;#{KB` zl9PkL*VA?+JC-$>^@L=7a<2KAllK=Vu_xy3J#Sf5HGF-*-Uy2*`|<2~gyMHU)`O3V zi~*0e4sxwS!n`N*X*#`@V&Dh#yuRVrOz3wbV%2uJZbzo$yY}{NXgl!g2Jdt34!Qww z=)7%yUx}W-umC76f_d@atTfr8tRBxCnUfwH@N+|$ANz4w*xle|&PM2~ESSTbRh1yj zRr8B8{Hsq~F-N#6uxl{v^7lRL(%UzxE8Je{J@o-yTBk7S%4pr=&LNHm)Cr>JL{;an zzkg=)G5^R*Y<`u^{84sf&9%)KO8)UvYp!1lkICAc*}UdSSQ&kzRdaMCeoOdyn{&4M zHR^y@mESitol2;gb6b}8{lG|EOl?_WgX=(TbHv_*0rzhB*gPm#I(esl-LUZ#|JN=^ z(d4`3Lf+nQ)qiJ}eR2%hJJ`AC5z?#f`j&s_wUaOui5?kaZolP(7u{J@E<{Rdoj#nmPra;rZt(k7~dE`CTlnRG7d!H6IB zP~E&mjX9~h4bcVj;u^&#PEIpzY+voG5mK(&GX-me7$qlDf)BIw|SxOJQiA}Ox_B?0dEKkal!Cg{$Df0 zx@U)sf8u=xnLTLVopXT`Snr>%r!M~fbvATA4_e!U4*JFv;5m{&0h%chiO~)jQyJ*s zx-}X5oof=NETe{^7wFKFDCV76cC~jk^;C7>9U)u6QuK?_ ztPFy*)UkljRBA6YJ3XT_72*d>#Viy}*~9g@>{!$flqRp~2T=5e1Zc%7wb0gL(NFU@ ze2^LFL`QZ-m(%jEnfmtUK7aPkKA+7Bqiw#VT~CWlKZiaY<@BMiAdsd*pwlDm!K|qW zMjiZ9lzqkJ-!9J;gjNUWc^8&+_5Heq(EbrAujag23bnAtnOk;FkcZMLQol-G=Uc;+ z6BZqw>aj`Q<=tzOE&&mTl78f{54pfuZ`+?mc`rMVsNV|md_ zI(vg_auB=i3&|lKCN4clM|BM>_x&U$NX=c;SrzJCZTO;_8E}U;kaCvZ6ekb@1by2) zXD1eFom7kOGa}enbmfnb_{dJI_NSN<12%tiz*{um}=ZAP=i)H;(IY~l0Tefdeh+m0I|t{+wtW#~_=mjzEBIT0`A#r-aRUqp#5Uc`;$>WELG z?sA`UQIV}44)K$_;TeR>ZGC!XEm{(tK~$J8_SuEpY>Ab< z0UhRkbKAbWFdcSxL0DkUwVF+DrP&5yB!>a+OEV|gU)Lb6Cqgfst@E7US z66blLYX%~kwqs6oHE#@4yGE^Us#V)O$h$pfV-a(6cqF?pw8?V^I;@X;Jb*67|0<$0 zzdtB6S;7q;=n-G?Hd^O-dr zN7tt205g^Uj{BX>)_h>1fvP156SnfPN~fDWrI7qB9@|(L&R{`VltzQzHM=5y>&`Of zClikTsfPjgwuo#M5EiYS)7gD&gFDdy^sfn{l%y`GZE+7F(MF^E`ERwne-l&bpo;4+-2sI7yYeznMnmE10A|HON6Hb$Q0d(6len zsOy(zu!VCzxl0OH-z-b6*z2kiM&G6L#rR{v7WGf>}%;!h6?r_>7m}?Av`3tGac-ni@hqk%L#5(roM?f zV8#?zudH5WbIPWj8SYgH)U36UQEr(P7TH?MlsvZW*|FaFQhxPgYIjG&Do>9&HxRQW z2}(vI%6_<-Fv2b=$*CNX+k?Y4LnkS1s_=n*d`$XR_NTUz&&+S3=$chq2sPY5X55O( zrENMJ%~WsTE^$Gc*?iH~^=ks&1+9lj3Gwq`tiy(J##c2yYAn>zZx+5R58562K5W5{ zOTVlnh^5DZp=RkY+;%^Ra#wD_Co2)$J<&+&odq?8DuhFh8W8{$eA*nwp`z6*PmBOo zWB<;NW`!Voj*OoDxftd^2T`GlHmR5(W6T<(c^ZhZvrR?{^G5o#)LJI1K$1p+a7Q=W zMlfLx5_R0tDK@pFZPE^+HVd0j-6CY45v>>D-igrf=0w;#5^0|V*9Sqn=2*PZK>vYQ z_CahPW@1hc?$2y`tzFoPH)q9yk+mFFC~02=Zge&{sqFkr82TGI;H zboLUz<;|;$rkzj^K>!<^)QJroRcY65vUo(28b+IF3TK*%<-s=DyXjbcDq@0et&=JE z+CbABZ&To;jT-@l2;m*txtICl_d!rKk08z0OkTxM6dX zSNrKZTKb@gX6B3rS>)eLw7HvrI|s0qMe-3Qk4KzZ)vN!M@a(gN=(#5 zeruGPoS-E~@!yr1wkPrFHGKL#6Wwegm4XyxBH0Y%Z{kxc`DXp;B^Ul=TfV6qph!S+ zua;`|H&%nRy#TdwnAS(61Sd|a+NXtLl0r+%6O;C+TLW=qq1l)SlF;Cnba3*AhEfTb z41Z7*ASF*t=~h#uB-=fFicm~xt(-)wZGWiD8x?C8kPs!_5g@h*(ApqP+iP%hwsN(N z+5Z=>woYbw#Ddm5t!*k_CS{_z^U>81OP8oGA1BcLB4mL_xfFsf)}W+3Sb-6%VXjPy zLysD9UnRJA%mE26mbx+VrV&HVLy|wCoK7Q?B*;V!)J+p1$U{bXC=Ry#PsBrIOO~03 z0JsojnF?9}i7VmxU&JHRRX+SUWC;(MuCdJ0Sa$Ii9(Tn^To)ei$CNR#dJX2D1gmFy z3>tAmOtW_r_fv#xdF1Ie6*PMWzBPu>-}CyL1`So`e?@lAF)=5#*iPf{JFNXTrp>O4 zYwcNto8rScT++9=HFHdx>GC@8mmpG_l)EbbB@Jjh5D@jwX6BI*H^;vD2jX(;a})fm6HQ%Wh+x z323stgHYSx#v?YpR-G{@o92~3-O_cbVH0aE0kk`6he4Br0a!lsmHx)W_|fV3BZ&e1 zo4#k5{;u1c7sR#d}OF`RyKK+yWV&fn0 zlgvgmK5e(wc`xX6mrw7UnGpOj(KSpP($Y#mjw6oRQEPVGQnKrcMu*6*WuezD$m(X>=x;$1h953_Ay#CM;c? zJ_-ot>Ffy;NAktyauY=fNI#h2SKi`(Kv352h-?$}Iq$J0NW}83;VdG9&m@`1Q^1ih zJ}Zw;_Q-JXg3>1V7;mWkb}d84w;Iw4EiO|(fmDf_-Z#f$REo|2JpHq2%{~(aElhoA zqTDpS%p=lv@}*anQPKn(qGp^AM?5P4UMXsucIs@LXJ2*~tDU-Ry|3E7_I%OGHyAgP$md+_Oj;83HczMG{=#qAR_6v#ue zh)*aH&)Wyuit5PSLzI%MNgz0HK@IgUTbIZY+bEW~^lZnrP2_1m$hFiw*8WHJlr$41 z4cMBtfui73x0{8#+RtV?vrtW)+sP*j`7VWON(@N36Ym-eQnR0}{cy-O*|bfrCZ~yQ z(bsHZOg71;L$P9OIG<|vRRxN{YfaWXF|N^wR%p;^TP%ur*ra5OJw5nY9_|OvwuVo* z$0JBW(P8Vc2SwPgJc~I=`H&pltygBqe`M@Oc4#nM-W*N^P2j;Ycva~TST+xlE<#mC zA>1^;b`ioA0y6{hJ0aMKFr))0s8<5>+!mJeyr-?rntoR$o! z3jQ%-UvUV`{T`M_@YMI*2zc;)T|w`G!h1TrU*S@S7#&Zb9ge24#h6yL!@KYiIbz6f z8Sdxc95WwCw}jdY_@ti_c1HqhPDNsluBB?t&RgL%HEAML^l77A@r4sVPIIgO5$f0D z_ZXPhHoKtS9fN{oHz4nt1%bh_xM*^V6I;1t6D_rbBo@JvPqZ-+ab}^SiCE%kZ_lTO zu^c8A(nr;l3GGQYHeH})=+zE274#Y{L->Nx@sT>AWlXTDEXHCU%&b|`LG%BcqBp%P z>!6u4B6qE!4+3;S7o`YTvAHmnbe>kEm7ZGMiN_0W4Eq* zcHi!#-2rKL00#PR%5Pd)n|9rE&({K|a~rt6q@D5-C=Vqq@72=uYWvAK0DP~e?N-xh ze_Z{Of9-UMX7dYbohdnGF+;&(q^+r`;MY{KbXDTD8_tsNX(@8^>)&_x__oyhHX_ru zC@o`OBCU%5IZjJyz}eQADEGj56@Wvf+7b1PY`KH3?6`YdOTpFIR;z6iSY!oglVh@} zRNH1~ZFiY$Hf3)6R8P$o`)^{^ZB$c}Pua$+ZP26&St9-UHrHQRR6joXjGriaOWmqV zHgP8FAeJpsyVXx&`;AEmU@gk{hELssmx}fs)ZmkzqlcKNz6;Q;8t6XN;2^}JyBgPV zBd1)0Ei?9yXuMY(K;)^An+JXa^LbAJYfHh7*t7pzcAh-X?my$zz_O(cM*K~23i~1`95^YYTwG}}eAC%}q%>sj zhqZ$e_VaMo(9Xl-Z&xfFU5lF}e_Q+LIS+sHE$@m6_eR1|E_7f5OOJ*d5x;)0eew1S8pYOf#TJwN zWczK;S^FZBd}Cq!%M;z@!au4aHn8u-Vpbu0Uv#ImQ~pU}jB6RM_>90+^iA!j$BWk# zM;}fqJ2bQgiLs>!hi#$TSFU@JwCL(bn%Bo$Zr3cRl|6ygr@ki3Sma`0^$sm#BkaPN zgH-#Ow9R7HF)!P|_^w}%fg4Q}WeL$gg}e)(QFm*x4N->y8=5i5sp3@m|p!;p_YYX?U~v7wvHyLB_p^$OFPhm!P^C8Wwf+?LEx>J z@>YWxlBb3U2J_B+IXj4V*w7`J8E%QJ@Q$qcO68y`C`*1Fvia$|0k7h3u>ja_6b99j zGp~Z|{RW0#gn$rpVwbB@b-MszG2+Y3Kd?lKe)f{qU&{fnw{DI-`@o(PIWE|MdHX*3 z=H?fz{}pC^P7TPd{VC}U9_Us7Ag~E*SEjIk+-D_?zB(zqeDRYm+y7 zPhgGrWc65+_ghFPyP~an(uzAN>bbvgym+gX*UMSOvIQ^4rT0A#1s2`+3@1IaQeH|D zL_6IQ%a$#C9{k<1t~QJm;~Y!g5$kk#5xasTO4g-0=cTS3@D&%q9xi;>ojEV)^{&j* z^-o3SS^VYOd-fd~pBB0XOJ+Jn6U>3CZ>#A;?)_25swCdSdk4pq= zjAwUeZl&kl!Me%?=KLh=0xfza*4wC9S+%gY3tHvbrhwgdI>*baoZnpE5$kdMprC4g zqgjAzeJF1P<#;+rch?C_5LC^(+&L2CAu|N`&r<~^#?1Q#{6C7$J)Y_Q{o^y+?1VYb zAJ-BT2$LK3B#%}~@mI#DW#Qn&g}snlKW z-+upp9{cOP>+^YE*YkS5eimc6D{n@O1DY(RjgPOsk((L+u@GpPKrPg@W0+ze{ z2`$A|{m+f3JGilShj~?@Y0rGK1a87a^#|y(__rn*yGh*-(r?`JF+JS(Z{`E{;Q$x^ z%xV0AKkNP`+|ab5g0kZGJuMA8?fD$@t8jhelimh0!%58(D8OQ|hJR)l-r7?1EC`TBz>= z0kV1TYKgY|oZqriKA83f5hs?KzVGA{nSP}yYD-4lEU;s;sz10$iMt>KJ4G#*q|^cp zZxZ})&kaf#8B+ZV1i5`^5+b5A+2H&X?-s4AI0{8gkY|8N!9Iw%QI5fFMWs`&T6rRC zgMo_k?h+Q1F_z})y4v`ToV6rG#%Yd02b+(}9fTiA%-Au&cy>9X0FfX{*1f@$+k6g$ zhq5Bf21GH=@e>*mIVnc{06&MLL}>Kvf9h=xo? znS7V}L5&zCm$5d>-!jY8PRJH&-zPvfq&?71MRSY=vs{bcZ?r?gR7M82VPVIh#X2I_ z2k;is_q$m$@27gaSwSDp^K8rR)*l$FbXW+No9*w?@{N&cmyv6}qd9HQ2o$Xw^uly;riN$Q zaDYq=^sUcw4QAmw=Qftj(?TydJ||FA=8X{dqtDd0q`mbVj9I^T^zgF{PPV#bZSr~m zw9J2=VzztdLl-lh(r6M(Cw+>K*DWCXq328%6daRDI@Hfq)x#ZyY@AHm#8ObNO>>8d z4uOu87XD>E8G2Uu^2!Ro(c@hx2h?K!z;GeJ6reRjiEc0^cw=UHhD zsmdYfqvnCddqH!IokVg*YkbbR3UfiB;Zo7ZNeu?LTjq^Z;#4t_P>Gq zWb84x$?`K)Tiy_v!$nPz_1q-~OYB~8OoMnIcnL@Bd-bks=y}XOkl;?sQ&|##Ii1A znz?*`$I6^h%^bMfb}w8NIFQr$c8DC>tDz#6-yI_4VubXD+lpb0(y4WO4Z8}{n|67M zXZ^olQ7a1@e9}RhW%>w33 znsfd5z(2TftEvAi<-}O7%Dy{M>;hrw+(P-DOjFkkXp-uDr;05&6L3GTIh`T~wWk*) zaGm$jDTNB`=n>NNMUL4&bd7I!j7l}8qkUAWQ>(?KQq9aZw3TixQS>HIzqZXx>vPR# zKdas5bED>BtFOg7cVbh1C!4*mAF2MdJb$=p%4~`};^jH;mi%%}ht?-LSEH!0avvn> zvm&b+nd4=i2Y>#K3BCL4`oNmq-M6EMt#)p_GGz8H?X79TxXb3x6WUG(vc`Tt)Bc0e z{-f0XkHC&G)K;l9-(sTc<`5*TmJg_45rtn?BJf`<`gX&I!ySI_L})G|y2$Vj31W%~ z|HyOrL<#z7g_u%B=lbJ1J7KtUQ8j?jktpO@-Wq2(h>6Wi^O%|e`9-f`Hk1NqQhJuDFKo_rp=#u|e zgdHS6xhNej0M21)&dfqF09e2cO?I>9-Rqjo7>xrL;!bpe6rEmIk`6wKFs)RA*(|fi zkD%Eo*w`cEQk1B(R`amPbhnYFf)4S=n172FMSh8sW@?lxA8_K2$A%2oGPxTH%wJm-G!9Xwcca2x+uu+WMnV9pyi9j2pQXugWaI3efCSB4M$5Zp$!}1 z?+{>j2-?eJ*cKbCPo<MT!Kq>n$od&OLQ$B=w^PJMQTTEXDQ}%X(`hY{%S&;N*aW z$YaYIR(^uie~h#l0BU|Vyk*o7)B;xvH0qPL`CTgny zfj z5OZ}2O#n#T;2FTYErwJ=Yyms;Jqi?&-M82x3jrpLdLp^(XXTC(a z`?dsrl>j?$*kUcu%&gsZazhe(7E&wnx&nndfT580kTa|BU0Af~jZ7c7k=YN!nS1b7 zl=B9?^TnAeqn+x{0y~W<6y<#Z?VxQ+Sl6{whrsABp54SswjMk83{X79N&xGHRPzC|`?<|5HR3Kq$6Ka_9 z=s*|Uy!W%%p;=t7n0=M0HJFUMUWAaKv|Z*jzc8^rH766kYkj!Tm|(;H=l#jB^R(Pa z%bXhH{gaw6xta$#oe2=OX`c3f?BhS!+NN=hr>wfw<4)GBebD-JlHJ^J9MYER?4Z4h zVYg2o2cy|)cI@}4V?S60q;9;!q{daY^GPY;4jaiRa^5a%8YUFp&fz_G=iQMMR8tUI zUqiTcs;Yr#?1py(!YF^+&YK zYs~=~?xK!qrfDS`;@ci0yaifwEUk`oq*ZbLT~xuMszot|Fi5jJ@-t8}$;TC+vwgPr zEuF_@Cfy~|D_wata|mB6ZYLR9*>Yp8=Ijk4``h}5h^riNjWop3c3UfE+qXUZ3n*k4 zdD}FH*hq%UBH_nU%R(g@PxQ-~%dmN&iAn_d&voNUu|_}2YetFsS#RPm(YUqHwq+K2 zi+?a30O}!nKLuR!`;ir7*bt#gm94b*`<8?+(93irh=M8Zgs%XghY5*$SQ_)QFsJ-Oz0Gj{kt0tP zuo1Db?%B&fVX%Z*=!g=(Fw0LGMc~7?4g;u!9%gH|y(O?9p`UWnMXPUj-hZA~9p^0O z0YN(hM|wA&O0vOzA*^W?uV|`|o+t=pyPi6}ID2zbRQ*X~tEg_5}smG^Joa*WNWqZwP)4jDH+E~;2mZi@1r;f8{PF!qC{mjz7ir44)6e%R!MPfM{I3{c$FY3bIZ}h$1s-F+BWP_qr0ug3 zxTw@xWU0m+V)k98R$M%i>5rIGAZ}%Box>n+(vclXYBvslUa@!YK?ki3u(775;ci|R zGpicNy@GK(^I}_%D5JYB*k=lUsuNcFjypktkLx3EvCqs`ow-g(%fAI{RA`(Lm*q<| zPNSfeDCl2GSRNUExUqtbLEI`ZsYD_AQD$R$&?frD85C?QTQVk7Ify~copC9LL1RAU zD&Oa|36qJB(0;)BUJK|kjrEnYWw#nNCs5GynD%M0)(J9Vvdb&O2yt729G5}AMPi;x zkQKA=0ajIKCi0!ge|Z9Ni|Hf32%Tjj4zo2%#So@aV??1*As$i?po6p1pGk|{Jdxd^Hq=p1VTP9z^eI#7^VHufzP?FAOH4WmUE!j7;DhUt?@ z@yN@k^3>C*KbXi)Ca#u=8!E)UR$wIrZ51B%$%y@4W0c*W`dyUDI)9Ii>uxF4{;xCD z(CI#~FR&$_ZC#?>+J5iEy;SF0sV?bnj6Qd{roV}E*FG`xX2YdMHQ~MX(&Ov@9r+@H z&v_D8JYa)tdkuAD>trpxU;YwAxFg;c_0QO#EAf*H>=+r|m`!ZKKu?iXIuN)M{PDd_ z#3#GT&J$o!2(9ud`0+`NZU8)lH3a8I*JOFH2{&)DtUH*Ng70gv2cs!#_fwr05vIG~ z?ST2OpUy5Q5RO^-bwf;CaWN8i2<h-$20c=;Qt}vyLQC+?(W}OKLD%U<7~EeqbljKWxTPiVHLCaTKiI$?(zHP|DG_qk{Y** zZSkVPbknuG|24o~b^m^U?>)OyeM);9Iq2vie2+M^%jK5+>A(@D_E8%xl-u>2WF+Jh ze6JhtFY(*^D=-cdAz{KpQBR6aZ97I!TBC2~_06j0zwizKw380Q8YK78U(AUUh~G6k z0a}~>z~^Q)Xa95b4*IZ#mvy@+Ym}`O3F@@cKn#l!F2ISmiYov_0pMS3@8Lj*jq2pz z0u<~&*+_n{%4S51qPK zeLL|FXQ6mIK7PZI{-de)9|uQpY2T);gikMqpLV~D{dmzJJoNI;>qq%H)g1}_p0ISk zOEr#(*OMREw-)JNjA`Rc+n2!2+_M%RPOFbNzT9;9^T(u=I>#5i^e)Lh^$g8OLi+Qb zeR0F~(eC?ZdfiO)KgRanD+R`!%2?AhIUV)XYj=rjx>3*HBF=T}I?a_U569vI?4OQ( zo!>3z#i$QBHeO*=+fr|(j>THhD$V)!hf1cqYVH!egzXM#W)?DJA=hKWH1o6)(5eXED{z zqZ#4L@{^Hl;=7Z{qGdTcnW=NZC~vrE#5}9C%ho()nD`Ezav%2&tsZFnUY1+#alR?{ zzK6zNxyr$j_4%JCx|-J&bd8vszg4rrTKHtp&gDI%aaE2+g0tM?5^j{{4RI z5oOngFuzYtE1Z}zYE)0 zMX93bqd$3u5%fSmlx?gIPdfn8(4x)w#pLTU5(nS;W+xW#>jIh0Emhe!KL}3l7#N8(;1)LF-5{G@6ufP*%_$NR4NzS|5!oa)@k=#fCbo0h!89GB^RDq;Z^6#mc-A8u(}1X? zyoPOpq)k=vJ4=R56rAa}Yi~mT9q(r}#toai|ESe;GV!~oX=6f8m*CWGac1icntl7c z4m|_5;F;EhfEew%S@6!7Mz^qOd~VrwzA71>xo?ZoEDTh!$s~}9-mYM1SJBNpAT>kZ zN``c{n`I{64aoJdS$@;qCp#Aic_YW>hV?&twHvo0H_Y3{Gm@KRHZA)nKlBh7B{T4a zXL3hzBgCaJEwi<0OZ15HwKpZQTvt9u^Vw@ukvs2{LZ_I&6%Ubym&cUOC77>3!JXOdqRdsP?=rTvh#c5?#5hIw$0VkDzg2~{ z&Ki~bT7hvME-2m+&oXBh4)E~dRTQho92?H3C^y8T)xAr z+-rZqlJzcqF}6*#Kt%Te)86(&0MJ`Q1x7PFS*N%mQJvH@8fw?{Rv43JbKzM%KTyr`M&j><7Q0Vi+K&hr zt!GZ}-=uOkN*972>9pxaocy28Ysf{SW#$B9$gR9U7(n3Yl-N;;^a2RY6}YA=$^{em z7aOOMlTjB%pgc?$G^-P6@~ltp{!Rp@nWgBW&eFW%_f+V8fEo`g!2&GkD%2pwYy=}9 z>gzl2K&Kf<(FptUW_Ysb99oVM*r!=ZQu=yN+^`mV7weoggW-Hv?R2Yxcoo|ZQ4@yQ za%vE-H2kQ#uc_CF$1teUx8E}CkDVU!*lOFB36X$*G~#f|svYJJEV$rTCF-@+km3D% zq12v7PJP8cIF7ebmS}B?(ut$}#F7fUP|RC2{j}dM-(PuV@SC zzFsw>XPx?v0~3f)$(&Iw9ZD#uj6X1RE>dErApa zVZd;W{PrUaTnrr%G7VEIz|I)U|=V`P}a-2vP5k=c)b4}=Xf?HXfhCXLb6W*CTT_ay#!0NyPd(IlXoPEj4B9<(sz0ICGy=yfha@mZTEL^I6x7Ap_Aq0$>u>chkmRHCj?9}v?G`l~ z<3Lf-AuFM6p}YAIxr>1=SNW1zD6e6Y=9HW78w63Cv$3nEzML!Sv#MeC0 zIseAu5#kYtE>r%f`3R~CP?Ua^UAvfY$rqhvnOw=)H)$FMoU9pktsn-rGq>{IH1{E1 z*)-nc`|vCpCpXzC1<|30*GJTTYhVxiNyH9--Z}}(NTUny3|aR~G8Z1*?iFs~cZO=F zqE;)slXYw=&Y6KffblGg^1vZ%?cG%ue$;0~eA@n$K>JaL(j>TnY%Ecpy8~aAl*+`& zEMAW)u8Z4!6>_^WevpwI!YkC))a5Yk5`^a&E>yfHKmfTFXgUtso}>lC@%5*^hLpK zN%_5LM0GK5n_c_v+|zy{FujlePiIq<#63;rtQ4xeV37TE$~Z%iK-Ws|(+ZsBGx{$1 zr}2DQpk$@GUmEXQE-w}enH;U4GeOCn{9DC5y5v$;A9&AfJ3Uur7Av5UJoLIyRz7zw z$?%LSw5CVNgs$hfCnP+VV@m@Q=J>XY+%B}l@r!dTMkCTr<6Bl%j+zzFR&yoc(5^x9LNn@17dYfb$U+@R9{6!cM@z@FFDq2n=Ko&^W2B=)2Bl;4-*j zmOnLx&6))RKZ3!Tuwr8kzsV}iqe5Lv=z51%l_~MykuBr)g~fe!K_*$mY_X0*=7=Ky&U+>OmMh}?{Sl# z(CNAv&2xpDY@SsQQGhlxfyX|oht6_0FY?a)fbh0tiSm`A0w zDYp+yLHwxuxmZYs9e89AcR_6SK8=67-Ryz*;RCf&StFR0E358?j1S`8kGj9IC<__| z)7hu~6Y`jLh4)Z5b_>BpUvTNUyzDgaF0o354hkjnH_>^1ZI`w{Au+7b$Q)1}i~kM9 zr_(@B_JW$EIDfQ;i%2a4U=YdV--!2MDq9cx@N+ZQ1FP`ikMC!R!MesUtPdou;f`*< zVVt7zyRl93Dma3s@_>ug_CjZ-@mVV!xsuJV*J5>S4RkG!8O_401Q@5vuXm2>^oRZu z9EMg4Xf;d4w{lmTk1+Mtf4E}ki8b}S>QZ@?*HJ|Y8~t_D$&;_)Ll8y!JUET<)7I07 zA%Q^=Rtg6zNTyT01pH52cdeO5JYWL@1R5)Pu&GE@iXZ%YH*ty=P6z>I zgCoe#5~e`Mj)DXJQ4aw`#!M8gZGHi|0=EqSfYrSk#S-3Tl*zC@Z?l3+R&ae*L7{1& zV_BfhsKF=_KLHJjK=Do&@i!B|7ikY7+Q8Y&k_4zmR-b^Ddp#8@b!GFm^>Hao$SyYM zg|hTnKR8cVj+ucrHR7g3e43>F;0v4T>I zQ8&c*tLDB=NI-7L1D4kPF>o`sjMi41Z5rnrF1^*yPhmhV|EbtQR)IT^TYGtK?8_7e zoomUBTm|hK)eg($#gQR%DE>hzZ))%L-y(s(U178m;LFs>obqps^jO+d7@y;jEV+{c z$0T6{(d0E$9Ie{`7In11D49^03=L`rfWo4ZH)>r%rL5ID!r6E!Jm*S>UtBcKqK zi&6K=mN#2%7}s0ad#v@03zNwjXCGs7axd?&)8OZ>=csM3{0I8}S4-3vu;YC`c<5=u z+7vi(YPnOkeQpC8v*tMVdF;u$_|Q{jN(^B=hJs|JZN zd}}m_j^=S*s$WVA_JQ&e-!za}yjB2zEFON=j_+WuDwZ3frg(v(yG}I=!&sM@&1QUmm!M4dd7OA-Lo5V7bLl&&so^%I zqs^=#-12gxTR<-V6i}X;TO7CR6jK#(A@OB!X*=O@#Hy>1a zDSsJS636&Mg6yK;Y&Jfaiq2j%;-c(gwCWGMLP_i2BW z#%Bg^9YZ}2BVeN1(&Af<{rL&?i}{22zmG2dA$lS#AP-J6=SxtL1n}=;-MaS{?W`*= z=TDx5jI>-*#nV9|7~p7BEu47LpzBVSYo%Z~o1rVdl%1&Sp=&uEzLSZ`U%k(d!DbF< zNX}KMymhW!dfr94C;Id}ao_wJs!e}Ye@)%trxpYW7v+jI)`fc7q>3+WT~jr{|!hYh|%Y!`MY|3!0awI}l!994;G~z zk44@P-Zmk4cwW6(hZoA?ZxGS{t7-532nfW^Dl;5J>u=U;3ZBg+1W2Tn~s zKy3Rr6;OENzj;_R%OwN6Yt0>Ymmjn`O-}>A+LK3L%)4c=oKDM+uU4i*3lEflvKh+v zX1v?eYYyF5{)zTDaJLjj zwZ61A#UVHFyL~gH1Fg(3?DtDXPOS6!gN-X=22M#=7{>A^@aU`29gJGf)MA4UpYySr zNB!DsEn`kc-PV1)Q@7ch!+LpiLHM^LnF|5`${DLXvn>z!A?tCP$f;k^9puzU@4knr zo~OX^f3doY4qpM-sq)_SsR)N&tGa07@YE#Q^Id(!-|m(8Ya#A4Vja0bTF2n%I=)ra z+Nj&@8Qnm;F)s?j=_y_>(5gH?rL5v9IegYlQ9ng35vdd~hDPL2Z${E>X8-hMpmALi$ zp0jcoVnkeUdqEB_b^F`*+M>%3lnm)wCgeb!7JfHF%}cGa0*FKUK$vA2`Aa}Tg)gwk zuB4-C{K7?n?kKTsb~ONcxTYFzfAK@I9&=u(m(ecQ=iU!Fs`Ivl)GJONBQ4bIMG}Vq zexX~wD`a}ki&NQuhj3swJKqj* z$J0co)18Q37JA#fh6&@06A@cO`p7xr#8s{m46gW->iWtMS*L;1Zx-guD+YF6jOJ~{ump;u5>c8FE zc^CSqBXMwCH+bOsV{9p@OUQljv<)g7bCN@^eJ6IAQMHH{6$@j;&ecf^hdatI2t4;d_VO4HG&h=;x2Q`y%h`aOTZz?*(p4>&v9>D2^tICUG3rK@eU-=qnVJ+wOZ`f~h-2S! z3$6aG;f=P*3P`Ih_8V5;<7^dzKV_3Hqza#r!dWUsc9NDP`YW*qz(2nZH z441A!i~P;F2IHMrbO);>ejV4Szym_$3?fbeTFvf=>Kq@YI!clpY?e7={q%z3WRN9# zh&4k9KT@7u!?j%z20YoDT#@a2xqgSwYc! z`k2uWD%vH#yR%3~-C!l{+5IJp4f)rXQbLSJHa^tRK0sPZ{G!6pc!u-_AK$cPRO17) zvpUr)mb;ODjlO0dwaP}%8NL^fc#OUo+o5l1Zpe|RLLy3MD!C-CrGU>-(@KI`8qKdR zzJ8dDh{)%oN89)$r;R1C$hYQQ%J{XIaGlbmNV%bm1exPC!?S)bpj*)G{PHkFk)_tGeBgf|8(`vwB~3K)??K zC^_e?f%}}yCdvm!CvglfuktJ*Z;GRlb9$GUd^~8P@cB)SUYDDEpT|VuN6}gRE{5F3 zJhFH}S7~rwBzK6)oJrJALG@t#@K@3ivE$nfXBYVd=|uUVoRq@-jWWXy6xd;BlJ-sl z&jLLG_W^v^UY-jUtznX0M|%+iDn-dii%f+3ccN~ssJe*VQsN%oE=(SmyL9NlNSY++ z7`oGPe-Ru;pNiW2bK7mMYvov|37`DyLY}@E} zPG@LTZfl`djQ^rl!I0S*FGCqVQtgd6U(4`s(Dko5##O($P*vQ0 z+!EMMs;KLI7(abh40akpBb|#~!XLJBj50sR%0}an2;5lgg|uRCVOYry5Z6ren&Hi* zL8p|OV)SyZwFB$i;t6AF>L@C{Fn@wjSO!5q`W@|J_0;L2xurpKd8}8}o2z8Kl;q-F z(K7tnzX*2Ou%U#sblD;7!w$eXY&B|XXzzDXp~NH=)n)AK`lUkf-7*d3_uQt~r@S-~ zL3qybYuY}7jx__@pKsy!?0vH}lUKg)yCT(5lvx*fKAE*r500Vx6ptl@8VDbqvfrOv z5~BbP+?tY~I<^P-@TKp=g~e3-QJ+&!R4U}5dO4wZ0u~OU8u)HEwZCFg8tEEOt+~c0 zq)oVsLJyv|wv{`*^wi3E9Eu*E^dlr=%d=Lt8>YB`?O931K0_Q;OPTvC*SCl<6JW+2 z^>dC-YTn#2xA7(l{-lu(#{m+A`cYS^uW9|_Sw%fcHJrT#ZYwipn2Y(2`7AKlQakDe%Rn}OJC{w zoe$dyQtV8Dtr^v_uh`%d*)X4j*Y{;AnyxFT#IV|Au{1(S=~Z&pUYqfW^AGvC%m`7# zU`I&^1GCt~k>rX=HTurTNAL8hx28an_+b0dv1?PHsgjSRrLS zW817sHN&4@G>8w&M|P;G|K%-z9n9GXP{*${5_8pW;fF?Yll(-bW2H$C`brgEr#AJ0 zCi~t5zUb;l<4+252zeO32& ztY4psPb+b>qRK6_oPU-Pl#lm5rnIXz04=mU2w=Ku&#_CB2znHcG1zXbwe7T&AmZ3F zW*qDbNfMyEtA?PgIH=?h*g)$m9Kxs+zv44{Z!JmsL~5ff$ZiPRxWaL&=Q^ZuoN{Fj zuA}*MDMlha)}a2lnfrBsvYs|(lS6SYE0ro8o1bqaw>0pM{u2Oop0?*vGO*x<16%a`u_q{4?xI5wG=(18JgyM)@nP zzKG0mQ27p*P~}r2)#Iq<9IDlT&SZhKj2`IJ13bfoTLB6!=P6b+b=P?gfmQiS2j)3L zA;rryeK;iAl#>^?!Z)R6EhErsh2rJAaC)a~!w_X-FttUcUL2ToA;@gpfLH}}J-WeS z!N*}phHCM_HYrS+sD!5C^IjY)Gnp-hQs=`l8RZb|D9D3TCJePp!5pHvX{NNi7X$Gi zMD>a%ml$J#KD8fbK{uy{i1onTv{NBuzd~^gp@}=TToncYJC7phIZ%%&t}~hMK}1AC zpNBxH9`t7($e9o%$lh$24y9T@=f19?_6CoI2&*@IL#R@7yjOX*6x0ltM??kJc~vuI zMS6Qy>QE}P*Qq9~Qk?C~oM`fm0P03)uopo8fmlWY)sT6?)chd@i8m(ZdZht96w)X{ z+8S}3L!9FGDx^dq*EvD!fsVkx(l{YsU~P}Ed#;Y3uVvCpoTfPb?UdeGeP_a0H$fiH z*y?;4OAqQ?OP#oOi z(^aJa7feb1z;RkMd(ryTe#IxpBE|hX&;iuuFuLvqTQDb9O-t9f9&D`4|IK0zTnE4=L6PXi1s z%*$8-U@$As;at8^;UZ2m* zp?M5Z>;Zyoi2pN*9D7FQt!PRlIC>0lZ>vip|g6$`OFW|DjVFX8ruHGIAXM2p+Ir%JU*4C2-cF&b<%%K*1 zuHH9;Ft|T{$%1$G^czyVUF>g)6I-@NC@X(D^nn()qtec!R5tk~Jgn!@EGD$tT!xbX zw>EGxmKL?{^9seG<4GmSG}8!STjavnT-!q@g+8&9Q*nk*v=ZPhmswmazkPdq^UJbs zBo!}uf>V@%+9`ON`2DC9Hxq(as@acoFi4Jg^Y(zVmL5w$`&_C0c%-dLzB?b8B2;fQ zr}}*dk~*Z>@^j8hoc}2=t~oQgd#n;1gDGkCl<%}*%nHRRHwxb)omv1EFJ6r16Y?!mG0m8xXawgXb;mwOJ(dlfZd)>cy$ z>_VM(;wq!^^F4~4SD4N^SOcEwN(f7&P~8cYdTK5k*i@(4doB3aqI&MHKXj=-$H7as z;jy*IDs>|nh?|nG?kn&LNX->oL#d#rS6>^Y=mIv)CW+j}k2&9W2d-S!DP_h$Q zroC!>XM)`!l-~b-N#pQ*X9dM$X>zacgXxFyAyc3^eUk0wZ6a_aEha0`99ktafx~Ek?@rgox_bh}P z1tc&Qz4B$`8Fln?>b>#t@f8zU8&dT?Jhcm)BHaG$H)pQ|NX(^X8>%~D9&UQ=+guP$ z7^N(r_7j*0XRixRKc(*u(L50l=T0EOWNz|#wrxwey#W{Bq?Mph@1FpN8>o}k1{z+i zI^ra?&J9uoi)1?~tmhh#pPbrlroVNoq9;#DOSIOfL0%!RBX4gc=HwNlhiQ zg2yRgD~_S3IiwX1zFs{-`SB3Wjo87F;yw~{MwZ#8juN1gQm8t;m?^2V1ZdGIwO3H4 z_PosMdTHxb*v64&RBGTBEuTZB1yl

    MV+Ya zE>816EB*V$>%y$k)ZLc?t@>OypfX(99OwFH%>*DZtzO4Rx(>y$pLM~`_+VQKb|v~? z^(iJiAB;%0ekj`Vvf5OrvE{KP;s6Ir=GJ)u-B)CrA9J_%QMb$j6T~(vq1%bPeS~_> z-e3c}te`#?BR02Xhdt(m5cc5=s)VFELGfUd%D0gn61RbmyR>&A{hcY+s5x7K*>rnpkj&hZ+~B$vB? z`1bUJ;J^LA1~MG`q`mGpzjBD<5I1gxvPVoM_DM7JWdviYy;)tNi;KiH_i}@U6%m4C$ z>?P?O?Z(VOEUikNG(g?F)<@m6Q{nB)F$YP1(G#IL^IQ45}0)XN5tox1T63`Q>>;OlC-9@9V?@IIR@87 zW#bEHcvB_?eX(;e(LB&|Nu=kkZi(Sq`l4GqzPl`@){-dp_0WP19v);HY>lYu!=OP8 z9m`O=vQt%`;Ft&e>Wcg89!uA=9T*f~-7TtV@0-mSGG+9@#sVedk)@Qb|wJ*6KQVrjp{eLD>7B)T2iaUk@ZVO>W<>YETXeCoH2MM)zkS zidy%k;S<;V^7Si?ALSx5FcN z?f6%%%vMh~Xym2sJ&J|r^81gSVp{hWnLZ=Tdpzk{GW_cAkV}loE>ZjX4N?EAw(nu( zKMS8eIApIa+JF1?}Km-)~(N|F9Hvibig(ZlE73veNZDdBkN~I<~+{KfA8qx+blz z-O4EKNLV@6lOPUHBxJvH@dXZo!h(=p57 zXKU%=E`f1S`4KnWH0)2RLF^>WO50a~nn$lk23l57u1vKyY3T;em;0gG7xZ>v3R)_C zDTBK!11Kuc>&H1Bg63rUaw;=9P3OY(Y}!|AgF+@z zfMzfw0?#puea>6O25~;27w=o_GV2fixzMyD{LO-23KM68-g&>l>P%z4hs}n}@}iM5 zS4(|%HDxN7r|UD#T>5uUEvazd3rP;P=9#5Fe%2Q?L#iHdcq|UCAA1yc!6|fj+IZdmCF4T-oI7ewiz*Iyw;6Lz?Yv&jqE5hC)2hTH z0G+w?AU8$HK{nTVL$c-Ao%Cz8TKZ8LA2Q zcjeb6uIjzlV;yZzHOkuA4z}3A^D(KtxQhrjOdxoTI7eluArmDkkdgo)WNkEIbYfix zln}goCYGTBv2HJd3B@|5IXzM%B3B1bQ&a7_C+l9Ni)#ocW*Z{(bkc6GfoIlQ83&2s^2$6zS#-CpdnQ4z!=g9^%D z48;WjTE9pZ zQn&?CUTx8g#>A&6LYh8!7?w*D6=kvz%zxcgYyQF{?EV@GvM3ihCnX|sdqDXA=Hhhb zV7%n~IcC@Fi6>b9<4bv33&;LKI}ZiBZbNt6rkjNiHZ_c2FU8%(U1g2U2H2nav`AZd z?}RjapBP?!q;0$JStz59ddP74uG7EfW>Jr8+!L{_mG}P=eYsRay4I#v7#<+~fu1oM6cNhdj zddhV{=}S;Or3)RgLIS*hX8@vc()RB7QU% zT{&B+_+~UuxuwxGyxe7fonMzZJB?2)F2-F>Lu6iQCF%!;gTAOy9{q0!QuK&-DNvCv zjjBr|0qrc97Vh75m^FN<{YAba%P-gZ0?%AkpMv|R>f`nofqZCUpS7!6iHGnhwC&|Q{UFzs zWr9uY^}cD1Uxrsm>}nGHX{1Z{hkbjfuoP`>T5YC*Hb<@YQ+`BcKqJ;(R9z^{5bQBW z?bv@-LW0P5F*WG$8*nxqJu>gw!<N-mb*_ETgYKDbms zkoroOR^~2xf{2`NmA9Dd?DDYbv@5OhjsTe~1z5{7(1?cS2JeF2>h(`^ZTn%vZvOuo zd#s0u`;GmEeFB@}r_(D;KJq}`0-}Wff^=mud{7Xy_0KZy*Z+-&Zwn6?N2P&@JKSr0 z;>T;Fn)M3X7HJKCq2?d{frB@MAa4*;khz(kZe?QZDXSb#Oq$y%GMBTq~WtH`kpihDi;>*{S@h#WB1~y*O$f>_h z+17HxoU_Fk``<7)vjQr>d{!tw7ao2Th840gLUo+eRQzJQgWuU>l&QzQ8Z2$f;d`Oo z7&}3u;~i5IFHR*?u$yv#T)x-+6!6glI(x=h;^#U-fW}NI+WDf>oy6fR@I$ATasA}jeg!* z?`NFqEVb~7{^*{oBXPvUB~;{_XU|~DIZb!NM^5l4D7>#* zCUeXB%E~;!7O}a4? z4ge&6Yr~&mc>#cAh#Fh8`6^j}A7h9AB)1ja)u*wE+tj#XvF+DLcC;E7AXue*h?4-x z!RjU4q#NN?XK$bU2VZC>UtDCCQq!n}avd&0v2A#YzM;K8Yz1Bk92`a1C8=3%F}UP6 zgTvz9QP0~sz+K0Dm5NfyAXVT7biW}G3DrJN9_ULik$!GpI$iMo!fm`Jlh}D5C=Xl?8$b ztXnD8Wm7iKwHSgLP1E6>fDX-bm@z80+Y9RiRJu$hSiB{m)yc2glYNbcC~R^*L)kud zsFa;BHfPtPbGT@3e-{TnjD3{0@sXGCQQ2?y+>OWQyew}D5(j}tzAYK-o!Gl~!6uFg z(0t&Lq>XlTGo=U2u@~66=y1~lYb+Iw6Q|ZigN)^!kiP6Gh0c3?n^u|`P)60lbe356FW6n z6K2$#{UoFVZp-hpmrsia>H7RL@^*CIntb!75=bO>t>d(j<8?nD07f@mASvYuo+x@b zoA~YSyBb}|eJUaN5RsHf*d?|RxtQ@rGO32V^hoSYa$Bih@I|>o#u2i z@OLBBD8RZ(-$xIUoA+ zNK#=Qk)^hM`V1Ek@khP|{?!AbQeB3Q-`)L?z!Rv##&&Vo#Nw%o6QMsGt|Tbf8<+-C z*9DT$@JAsP{|Eqc^4-|y#@Wgy<>>H920TxKGs#y&N!8vBwhE03d?wF z=Oj4S3bgNDncTp}x@>GF35W`TH3f+I40zTcMw{m*Csb(DsFT42$KrNB788u7iJutQ z?_!LR&*~63PEfidjs}tofpRw>uGJRDViR8j9moRoAkZPF+hN<34J9?AG^*A( zMfj}VUn;=YQ29Y(JXc*&P{v6W5DvVlh%P3mc3}&rRu8L*4Z4yg`o|ThQBv^sA3!_f z8og=Xa(5O@72K;Nd}EQa)fBGUgSjB;j6;jmjmx)Q5Bh5>!?5oDloJQ%7*--uwh{An z_z*f?H4;{E2pSz>O66U#7ZNE#|3c%P$o8D-nKzRg9n#L&OP? zAWM9>+du8Z`_B__f_M^;IQa)g&$hDs?GWT`S}@u{a`_NNXTR(6_lr;LZv&6s048~K zq~?OLq!)+yv5VBs0EG%l}9<`>Leb`3``r=a)r>5O6WMg@V+pckuwcAZEx0J^bn~;!Yw3>q+)9nqz zPP;|pG9H3KSWboeG|@}=E7;!gKysRZ5F`jFpSr;n;|#me+D(LtZKQWM z$kBCVQJK$AB?X%E{?fc3Q!cq^QaABv{2W34hAD?JDk%p@kR~3`Qb|4~KP{2+GM5)< zri@_{d>q7KOtv60v;6{}vg_;dcV=U-b)CNu68!Jk6M$(;qXB_Kt(OE%MPPfH zu*qk0O1$#fabWwJ8xASy1{&wK@WmNlZVoQY7H&B^T$8I_C*3Lyt zXkd37bwA0YVtg)O^yp=Oncz3HPPr;PHP*n}SMvP;Kuk7>nKNr{W7wl}Q1^tyb=@HebcvZm}IU%ZT zQe^0i*014Bb=`lurdaoLowzJ|*!IEHz5~^fO{r@uytVB(BA|g$)tYM-}aS_8)_xiS7^ZyUMyvwFk|`7Pl#f zB=MpLa8HE7riMy3$j$7%Ifu!y!m0c#tqmO^?f*;x-~N}(4&e)k4YgPgAlgIcfcniQ zL=6AQwB1T=O;?l1?BR>LlFom~x7prlI>)tN$a>ujInXtlT^%-{lK~wH1gTwbNQbF* z1#fUZYw=?i(L+Ex$&@!rUsPnyXQYxm3`e>{uW;2bi`knb0{r(P?xh&NpGwGAD)E}n z%%8?@=0$9Hth+4K5qVbe_{R9t!#KG-rxN9O5{qB^`-*K#x)Sm%ajfKl07%5mj4e~J zI9C;9<|*`u^k{J(Pe8R09gATaco73mVfKR zv-8~6RfWj)>OKhfi1(!{qr$S~(Jql$Pmv+{^y4m^Yuns!96g`dp4Di7RXQ_EXb-=> zYx+h%ePT67c_8?P%sQ!kThOJQ$rYEAE^d5qWBpzO?(nMOTa=wKH@@1Bw<|pawO3-Z z+asS>q*h2j5iVy96rPD0^2%)A^nmV&81j1RTUb|7T?z6|aVb=g9g(A2Wx_V=VQm+r|>8-M*w)v^_Djppvkp%h-ce9rEV z2NKI~`j{QJU3Me)E*O&Q*8W|$^3iH&U}W9eM?=nsm=KH0uU5>~?TBaoH@c%Xpl2p2 zcEwwg^P|Qaq~mCBjmaWrMN_N=!t?d_)+q0*^SP)WrmyTs2)v{1JL5g{6kF$hwngF= zSoziB!LkNTE`llvu!xSocqP%o=VE{#j3Pq!!2>2mOJ5NFcf5k`t-k)X@7}RK<;r-* zkm_&^L$CP~6Q~#Ms0)Dpg+;Eo19e8x4ReQUy}MoV61;nmNj2MsRGGC5gAiNiHKcg* zAi(r}cLPHQxVhfD_j?jT^XhaGGVqSVHaD=~OKZ%ECm$?O0XIn*v5YHo(+*x&7qJMB z>-tu6l-1K*xO-1alkwOU=-=?b!LLb(o43Ayco_Iv@xj`EGT@(>&~9xX-3O-jb@PgM zGuU}GeHGkyxi!#t>ey-Jjv9CUyf3WWa`%%M|3mb<$dy%IdG#)u04pS8lvKHcd4!(~ zE8MIS5913NrM38C{$&}yfq@326X=+pME}!kcVBv`u!z`kFYP@Y5Rw?;+Cw)!QKcK58*XrwZbz!A(w3>hrf`1*| zw%V&PYy!WsakkXjd3y2iiyYq(%P4-+x{wCOn9DHFm+dj|S+!55@)mM*cNEcq!jfYvod|0SdLIK7>c_z_c~<_JAJ2xHHeID;Xv` z|1$O&yybD=m+bQVzy&iOuFnvS#b zuUTj)F-<$XS3X~51iA`G%{{X8aCbL$P4-MHp?x0g&DdDWp9~>im`&(C;ObN&x*Izm zGt4}!tNHbNT848TKHY}0q=;2I08zxxPJH6AydxV75?l9_; zr|FLVa)BXJv;4%;SIy8#j`++x%#w}0627yXU7Ip#5&wpR391ky7VDy7O46p_ zDq5nV4f&byE6e(A$haaWOwzMcT(Pk-sJ-gRvgfJwp9gtd$QsP7e~W2 zEOkRSk{ac^Ow`@=A_M+Q5!k2*!E70juR`jOyW86wKXD@c2h=EY29TvMU56?LA~Hl0 z+e)7EvN8dD*SLQDMZU~uGz{^(E+plTTC2mu=K9vx5RyY*)rsN~TE5Hte%<^~uLOhq zpI@)!ozUbxn%J;oo<^MEw_@%w>)_8voc>l(gB47@%K|OhAq;=6ro+8dbvJBC@niQ% zTiPf51c_k4R0Zv{S#WvZrn9DX)NRpA%RHR^W>u}7$DC$Z7pv0kn^!1+NJ>K+_0BxR zr$&uPFPuI;SRQ`9w4|1;ohDpbl>JDj*2d%RVh#d^DBRwvs6jnWBwRAoZ78uam**XJ zGOhfJ$D_7C)noeSBT6fQjgf!ms{7}`=i=FQnZ`UqbJC$z$rxlAr^!k79e*fyyyok^ zMbaJUyMWIND@qvYk$T3k&s-&HuOW~4;lxm7Q52$Nrl#(}D%q)wlA4*pk?R7BkqbrQ z*H}zwvVjHmj>jO=v^0G8>`>3PRoR*3D#A!~nQi?-J?w*o@R~Jp<@WM(3f(;xENTcp zej;J}k{;I!kZ&Ki>rEFP#`cq>w@35YBE6n;L>T7mRAOJy65%dl32zfv zJZyryCKTWeUsEq>3DJNfg#yHNVQAb2adI}DhADl!dMVR$ygq#2&f>MbtqA5+PtXf$MBp^Bj(L%m{%)4y|5<>#Yg|vb#jrZC)(%x6y1Q(V za_fgVW&`VB)x(LCX^kiM)QJU_jvdKr#*$oY?L?*Tv&2@RGp39(@2ZHv6vb-jK3ev7 zniIRlzjR$WZ4b9xW*Jh&92`ey?+c#PaDWzMA zGC>lXmdU+6#l5cU2Sc!zCP5T>3H+K7<@d$ zAiFi}`-c%-T)19^8091MT({g#cfAoV9Y&0~BKjGqX-2SLDRTHv&*4<|3@hho(z0VT z$yJSqdpUBkYJD%??Y9!guUr%!v~Gs$SOf7r2|&%FkzbV(o!0HH%60t_c&7*v8WCn* zux@JMGhD>0%W+pVOWBsrVYlbu2c_1pxh^nE#PwAhzRqM_Vr;pn^uD8nw}!w7p)g%3 zW!jWKtD@gwg`L;B52z3W8o$()iCvAW4lb1b>s7o%^Y9p3+ZwH{P=?sT zOf73eD2*vhAncYgMG1h1G3$>xr)TvN#sHxTxD9IoT5H5h3bN3UY>q zSp5buDuSN_#KAxwBP`TS7HWvgStvpM_Qrj39yQNKmAxzfkA`~9MGb00|K?UK8EqCt zHe*)E-u|q=DEk>H$k(h*GbOY?&^CWi3jd}!Pc)JTSbW z!2B6t*I9)_rov7Y{NfkS`&h(jF5P5S3;L5^0-!7!BKON-owyFPWnronh6iL(h%Ylh3h- zD5Zt~Y@M-T>q^_G1?W*$$$g@K2^=L24QyH&pwuAWpNq@1#R&YwNn?}*ZnP!QUSR84X~9>MkE)>S(rteM6<&Az03b9_@hC z?QAbsLQYk?WeUyeG}zu_%o@85j#bGFS_*4`P$mZ327jvDd5C$R;;1#68qpEtt3 zm~4)i$SXqx%qu52rA>b(=GyvEmSeX(VsqXw7`kcitHZQ&1wY1ofoqJQNv>$iHLJA{ z#9ls1FGLK@dyig+qo}CgP4V!=O`5B+1MP@e04mcA;+pwe7UC`o^%`)=`YIx%2!Y`s z^UlNKLy#SmxSu!L+(xe((o+1R$mq(D1wLwXl=EwfbJw-LJ3#Z#-sVZGXFtpR_{5oW zO4&ITY|8|!!&LqOxY@0G{hAX_PZtoA;G)~^@dk}s7bU)1h-}y1xU5C|o^k2+J@Msv z(9#(Vtd{}(k#V7G9^QiH>bdY2ZHTK%*Yhew1|3m^UEzKrt4jp$(ZKr5UaA(}qJ^E; zz$GTDy{b%W2EaiHUZVnMX)F_2;6#Q+lInJ%8N*P4vnjBY2~o-9#<50*BR=;;_rb-# zR$g`nvz%cKO32axa0_L-Ki)6wF65{NS!hfyIEkp%5;lpg(@StGpCOunb*lurF=NR- zFzmjus(}S-(n5=N)V(m`4*{$vDX>=*6k#acd#Xz*w61B?b(SDvo*>*ux){%cR*rWY z3j^0pK}{55QxhU^0Vb%l4PwI@jX_Ey+*fx<)eBW?%C4KVdxT3VMGQpJa!j~Ad}bbM z7PmkBKbC0y6TrsV2uLyl-#LKk&WHgnT*bYxpMtn(g!LpL=HQz=URw1Vt?qt-m+p3Q z;a2B7{@I^7W3M3vffSfc)AqHw;>Z%j6m@2%Ufk+sI-K@Vug zDwjUAtEFw&GQ3SqGh=xx`Rw06^v#eXN1d)G7`<-+;3)%$ zA4_dM@ln!R)UVuTqzqx=iu5d-cjg7Y4B5A2v|?XJ+!=G4*8b3~%$=nmRGqLm4Xpk7 zI+F}FW4U1jjT|xF`!9LyBPPe(0J0!LO;c=sHKP6kc&)x3{@cPay)bGzc7Esw7gE3> zw%U9Xynq6R3_|L-utXtv>4$#E1uZmP1l>b}VW?=w7l;?!#G`=JNe$G{O=|iOBy@%q z=lPd*Anq`_17Fx?f`yfBsr7)g$>ANwnytbL#6cB8MMVd`#}@2Jdv3x;Q=_&hwm#5e z53_8&E%9p`>w}+lUlc*zN1&}jt7|M6kDC;C6=Iq{q!L48IuO5^Y(6V(E;2KfXu#~x zKpMqj?L2(HG3%5De$A|}XPHrJc%=sMBO3MYT6Px&3guloghiaEpsu44(p-4x-n}zs z$k-&(8?Bmz?*B`qMIA1@!>8)e9^W$_-<o|%68zyqw?Pf!pr1Gn=K@Xf`qx$Rs?NNtdA!fgW|2}jtTKneg@Tfhn%w9HVz}&t z6MgIPiA0yp8S$^3CF!)cg9yjCd3e?=^zc*o0mHKE8u*3zCkK%Ot>4+X)2F`d@EV-2 zybb7o&qqNrQGYXxfPBaIM(5mCFaXLkvVTU9B;}PV~k6KJol_-*vYE-p7x?moNp-uOPa#d#;!Y zG2Ks3NZnL&v+fMPEOKk^LzMqLdY^`%=Ft5Lo?oOO{qDi0w6F_+!MxcR&(DprrKBbl z1`)z@?NP$#GFx^6;hWGvnzC`N2~sl{<(A>%TzT;NTbFd{?wr|9}Uh&!wb-Li6ThP{<1B zgx6hcebDR2qXDIYz+V`5|0uWiFG1w}Vcxevx9--9ySWbcOQh7+Z^V??H)zif}~NKu+5cW6}bW@l_mK-pizL$)t!TdudoKdiTn zJm0cfWvH?9bn8iDq#Cl%(9-^DW(cMq-E%%NH_1HP&s|#R+jzctphW-Z7w3o%Jo9Mi zfsToj(c2gX#qI^H|Hru%DR&ECp=(c&FXT-74Z(L6O%Cw;^FIa2Jea-7x`LtTU-mrr zS%0u|9XszotC^Thx?H?)F8o2NcgC%vA&+pyQ8#r*wBjO&;D~;Ft_Kt3!c*rL! z^B-r7pMBVy8c!pCh7fU)9SZj&C*H+GG4k&=t|Pn8)P#hT%pA*tC$=|kA?w_ZL^@T6 z9pPjO!Wtuc*qBq1I5B82tF)SVI6Fk=N(|rk^&MfUxIEttTTFv>VGEqZAYx?c^^TEH z)!CV2IZjb7jT~~d0Urvd{&AFZATZ3%Zu{n$W9#68f<`WPY$IkP4F-0|)-1fk6i}+e z9x)IyLs(=8vpq5_P9>SHzAycJ#ixoLKsK+)T|H-8Z$A811@_Cl8iP zw(Q+CRo{}6sdjC7U9@#0F)rTyAugrxrGbzMDCuZnSExIhvmXVXo8dnSgdRKEaHGp( z#~F{|&%Cm;C&;Qw_51Tr;k%GtDfi)W(g-&Y;!r9P=vP!UxEB#pavNm?L4f!yDJy@% zqdQ*7cHEjW=Kt?6>p{-)PsvhU!Es(GY?xFUPko;g7vj+3mlG7)(UhDLICLyCmI`i3 z;=mxd4K?bGo8t{S`ReVu#-)yQTZeL@_S!hZfC{8-xkV7?j|@EkL}=#R%T41#6&L>R#9g=W9bY-ZqBQ zLlWuCU&Pi)N(ufxLmuEa1yAIM(x024D>+oSLTPB3m=A?#IOAl>#MP?&jz}qu$jy~D z_pn{I7-lg8lnHN2Vu$N%iwNe2a2Mfty=RDpgU^j{^)Yovj+TRwjfNl7yL4_L!yF3f zVN`oe4d?4LI{4eT+J2!svT%XW$OpL~_gK4m-$gevjyT(o!#y3p5MG#Dw{T3iNUv7B zf?;T0g1=N`<64^(1FeByFdm7woE?ASXd&%rN&0>B7AQbQ*`=#C_o?CZvz4nx#~l(x zA?O%@%ndYN9SqguFR8+P3`|v0F^$--khv#~S3P>wNqm2PXls6BZKgR0yc2!eD?nT) zD3g#61|o_S7-adwT+$4{Dlnw;c*S(Y_F<0ceykfZPxJ6h7IXMsO=HcDd=&mVKb+Ap z0lxuuZZY3k`+q1=W~<)N2Tb8BBq6osDm~#2fVXxOh}qEo)NN&`Rd5~QP|hNaH0F+4 zv1=S5919`FPlYpg0UK_%93fvskJw7aYlh2<&wS|_T6Oa2&8%ii@*sbRnY)yTDxK8h zZ)=M}Fy?ylt4i{p0mDze4v&}5erjb`SFP4~)cw?Wmzcc<@*6+#agJ7kGe8l6S0wP8 z)Hghybwr?4pm}#%La1bgdvJ6P!ogbHRg#=}VBt}U$|CEWHjJrYB6F0}doI%<0lAUz z)NYB>Jqlz?ksAI8YKePPapq+%wIPXY6bU&*K7sYhrPPaM!5{=JCg<`8nH2^8e< zyf%B3hCff~^78A23G%HnI=c~G!M!y_7b8(uC(pT;v)3mBDzQgIj|e!5h0PI_Re>UW zm3v~X`AB%(jdJk8Kb>_c->szP>@6>zxaNO7T@FW~&MS`rm#jHGz3fCDh(cm2y!p_u z2B9UwtcGn-QLPIBV0vg{`^HzT*zMWlo+Mxm$z9@b@WwEHS&4O2c4cE)QJW79E{cv95#4Tq=O0kuw%(T*VPuSy{a; zZCrW*9Zow3vi4q}*S%0G6vI~(9KR4kEoY*Gq*EmDd@(Xbg!8w?L%Rl?kACFdwTM7l zaH_O$K6=PQK4cYHCb79N3!;}|6dA%0^pB%pdI7tN2hAn4{(rC&8xdnQjeCI~W;C3y z&O}0OyPAjGg5))+eCLV|;oO#GwS=WBaVegP#G#pnDUx93Gr5pC5~v0Gt_+A(Cex4u%lKy%rkk8X=FI@JG*zW2NTvk2j=+ z0jq~|x-&mF>R}z4A@^^}!6f?cM30YWBV|+YcBa7Bz{@cMtAu#%6LXu&Na9!jwI#UF zwwo-gAB;x>|JQh){Jy*Dhvt54>EvAja9Pza{j(aCUlyp8!h(ezl}&8lQ?zpA2Q0*e$ANw6PFezySqBEHXhVQR6POx zkU75gnF36I$MGJ`Rl}2(3N7&a8K>N$nToSPQyINRsT;SlI%!Tqx~ky|kGrV=&PQ!; zX(N2%7ho4ka`A5#!`-*7Ip6SY8ZS|SR~Q-*8+W}I8^vIAbA*EZT@9T7PeiyvLGp|i zgf5L~r(5rdC{0N4#qJ-xzFQ)A`J1raLYXIJ0UE$;i~d0Z@@ogyn}_{tdLxzfwZ?5ZZ7BZMl5?Gc6#({RYB+4IA(XQSsF=oE#Os!+TwR}s>;eilzCD5gql7n@uGwCrY+-4mTVJ< z3y`sumTb|@;CX4bsWRmaC_$y?O-kcuWjO=*jPjd5&RRY;+ZWJ59!r)m3r;qgnM25P zBWUyqXoL1BUn|Z0H=u$pzf*W~f_^HSB}?O>_I`6e!?w>Fa6AK(myiN;7*?4SyA0Op zWUi`e5u7y-z6nLw>mVgO@V}%VU*ySn9LLgtsti_@G!rIO!8e$|>p90au&T-#ki2<$ zo~F8F9&A=0w}i+Fa|hD6;5dpbg?>s!mv{ZEXeyEv`Z*+-PNxD+EM+&#GDl@XGpI+Z zQbrNa8R6{#c8yGEw5IAnbN{UXs~jFVSJp7x|Ch;5iiehHhwb_l6?aB*3&VrIAFC?Qc4%dRyMf7+|ZRh;4{=H&!^wkuxs+c>7%k7wBscv_)?2BodUX?ViO_~eOVMaPKwhDdr(+v zUpqKS5!^QzlQs`gFu)%UqmqTvH7rp0Xm9~%>F_EKsUP3cpCOH}aq);b4EFm>r-9qwL#d$d*DZhbMQq2^J2>+z~ErfmF_bGgk?Y@WtU8!Ap5;p&2;kLkd}q z;K5){p^VZ9;;2}TC5N1P+ag5kurFDWoHpND5LYNoNBjBQ1m_g`hTd?F*2+B2V%O?r zwwAtyqt3y$P2n(b39B+DS{hLZ;+f^%B<@u7St9OptAKAe#$toKT}DEf^}4 z*&&SaW3?G%C&bviAw>HENDQA(5YN59d(o#Ugfuh6K+_I4%4S}_y8e@)fFe}?v7KHae>JHeDc|Je{V zAZ?(6vS!U=qu}4q$>uVU!E3<@S~GBVf77eu_eiHUSXgzTAr%8wnlW%O+HS`zJZGTl z0L5xg;q^TR@Pz;?$%+2MQyw>&_F2&G(Ldm|b`9sxgDdCHUiXwGjdq_YK6Ahd z&M{EykqENI${j1E>3p|PcN~`|OJ(3ZfwFY1EPVXomRT7vTR#5X_9t8utKvok$u?|~ zy}r9CjAMx^dlbo&hL5+#+c)|CCUao6`v&qTYaaXfW~|f6a_En3ib2@};6G!R4o5&i zSxQh$sb=47RZ3~q_VTKb0E=Kgzy!6ypP%zSsEtEb3gR{=cV4t3+n2@ovm0P{f@K}&dv9gS2qd*vuJGrVnP#zjy zUvOTQF-oqI60;^{#_rd1?pBwZYEw-x=2dVu02wov)B>f`y=5p5FwqrMb0aFfrJ?Ji z6ey78GGyxmpabyKa|}o}<(7;0@k)i&SxnXOYAYV8Dyq3^ifQn1wN4Sq_;hd%B|2^3 zXvV1H#H8#7AF6AEkR;F1^q;?9U;A@d=Ir$y%5!x~0q{@PBiD!%3 zURefDriqd7(2!L@kX+K5z4v8#^PrbMN#|Iyi~;`Y1H;34zwn_+mruyzdEg}eu~)Dc znI_qp;(mS5>us&%$Gx2!lrrv+eD+~@&d<{Huc&l|?1%DY?0J;`0a>sL94!L(?4yJL zK&I(C#}D3knyGL8{UGei!m@ zjhL0vRlLd&?z>1HWXq_$c+%{5SMB1sE;;*Enb))Bha2&EKbuD>6S5!|wq`y4&I}cp zFG!iTFQZ=J)+NVmnvjAb84VwkL1uBA%>##-ecNtO`eNlZ8EEe~i=1)9HKdZAmvG2* z`Qsn9uPC@sxGYU*6YPsi=eA81CoFS*$K2y$F9lIkuIyG7Y=z38lhQSlB0BFS;HpJL zfNUQ@nu?YNkGjTru1`Q;<#tzDYy~f$wXCrNpPZLMcm5Lp5kz_mIyuOXBgt1i1t(2{ z#*5ZGxeS(tf|pOOf`L!R3E$*$Wpn5Y_x>>Vb%B#uwaHrPoqzK(DK~kV<~od3zW8-tpgevQ zA#AQGnRL(ITGIzTlE;zHL^SZ}r&7v~Pf!p!wvbec)ssSUm;89)zfzN?^EstV5t-O1EOAalB>U;dL_} z;Qvt|?_zwGma1<%WrHIXA_V=AVE98gBw^_n&peZd&*99w!r*2?dwr$bCSC zk)!6D*`vFlRb^G7zs%zOzbEhceJz*cX8xY zb3312VP;F>-+&U0Qdjg{M1hlP?rK?x^<{sDSdJ_yKsJ&a-RdMSE9E@q-MqI_Cg5yK z6*hto6mdo`3?%|b5!jd65U&wO{(qHN3v{`>mp2CR81RyICs zpTvOdpvZ9z#p#-=+vD_7x>XmeYJ?W$-`Qp7?^eqzsUo?vzvt9Hf4P5ObfUm%m8P0 zM<>r(ZkRMbz+3GeE&DYhUIET2rzBHA@e~lUCOEy&D_aF-Gh*2DB82D?h6KdSi?Ndy z{s2%E4-_Q?g>fzYS(XoDGx!mf!>=UXXr`YqcI3PI^X~^@e#2};zyH?83f-gGq#Tc&G^(XS}cC^G~<-?tV@_WnENQ zuPk+~IYLj%M;3vLl+N4>X|QjG+Jy6&9RU|C{&Xh>#7P#RA5x>y%LOfqg7hb+&1F+$n!`kZVKw+3c>QXYl?g!nw@;2FzaU+*7ZVPUNA>9qz~6^3UdWd|p}G z%*Z_}g?*GYkOp>4eLihVjL9}Pn;vki+~|?|Az8x}cs%=<-JAWjsIm6XbCZ{H?i@Ic zO^Hp@@bo#>ScH4~pdRjJP_@>2-BKnayx%C6lFh+^`RR3xK7MG7*Bz3xVC}S`GCq*j z*NWIZqzH}Q9S^7_1Qw_tOB!=O`!M7_&u3+ojE*ofw<^5Mm z@4L^KWsAcjqkRvUU4r^4pl8(vJ`)-$w?nwXHS#vjOk0^3%MW>nR^u67Z2WR#}ot;&>!m}`U9n#W;IX!y# zjk^wK`G|Suf~4AQP~}_KH2M_#!2N1B)|J^c*%z~QXi^XNB%En=bGvgjx5i_reL8=; zWzl4mdsp-H&l@_g=wmz%KASG~ytMT6OHH{}>#Oj-)a*OzV$RvKa&3Z2YCVtm>B2#@ zIwcUgCO|(5Ni5T_*B7c9Bedy&Vz7N$7zS*WRL3H&cjRCe@N#a;D3yoLh3e$E)?^xa z-NeF3gm>OU|ChyWaY6GvmNu(jhGAjORX9DwzKV4?F6j4}CkgA80G|#$-gMz|{i+9p zAu-Moj;j*sB+LxQy-rhRMN_QmgQnFYlVa$b^9P5Qoe!9pNUt+sYqd3;qX_11p_i@8 zkSTXGMad71r{9mo*91(`^?nIuo1QiPDjxlcr(ORpS-Iwan#$*E3fWkoy+Sm>!WFs{ z4-wJ~*ip)*nJ?AOOFzz!$N2Pe-R>9cEB{i>=$aqD=XP}wc;Bs6I8kMLR=ML*prK4s zyTdD$Jq8ka%hatw2?Mi-`)O2>UQ!~+M}~=yV4zk zRQF5c_gzIN*9LS??)b6weM?g88FL5(agZSkDArMPjjhC?@*$rxojGh|GyWzz!XsNy z8(b{G|8~-_!aUoX`&3UF;__F_Fx7kO?vlTfsEP)5Lr$Hu(_flM=Ia-Z6|AmseYxIU zm5f;Km`A$IK)ULr7OAWdYe!{Qc{57SHV((6pvp0B@>1z2;qe&aT4(&In5lLe-e81_94_^oW+!9J0 zGJ$JIQK;frNzjgp!?;Q+GL^N7aHiMA!@r~>(K-*m*;(dM#J;^^bE|C^H{4y4IG85B z7pJp_con>;-QF%C9`p@ou%=+iUj=EW7bohAnWL#LR<^x#uzT$I`Q=~5=IYCkTSa>{ z+Yp6*(js+tV^Y{i3AU4C@6BVw*6&{m)Om4Hbr2)CBMz#6*qE?abo;<>JklGt0<(R z78IKbEuB8kXBb7#5AC0MJ8Z)Jx>3;UM6a_4Qmc1 z%S^LD&l?F>Se2mk|D)*4!;(zf|IN%7j-wA>j+kUcd2wQu7z(t~SOuCr8(Z|S1_5oWA#AP4MRst6*w)+T*oSc-R1{#j8NBEPMk?;$fb0RHXLmZ zuHF9OBIyTDbl3Bti(4!|G&}_9D@r-Pg$iFLYjsVWFQVFF?snTBC6@A)7@w(YqxMRo zSYNctjU@uRH(I10qJPY#^F6I!jbxKvS>-I*P=2%|JEbEGpLyssOW9gkOsHeW_mW+W zkvH=B?NdMzMNPlUUZ>Gzm$*=Fsm=vcSJq|eBB}RZMSAgM)oV>oMF`1vB+INn6Pd4VGccw^r8H8OGQ^9r^fAXanfn9 zLh;3Pc|({CY>XYAH#|f0Anl|^IwI)**nq7>1a|^tA^*82BY?kRnS7&*XBEn#)p%a= zEHj}+Mpl}Q=R^sSPT^g>KH1R1N+Dtx_B{}}M6`0}N5-_I1f1zqIKGWN5|rF*)pWV1 zeX-@{1W?=?6A14N`}N3p#p`P=Bc~moN=N*fx`@#;47=f_Ax36#wSk_jsn0YGJbMb* ztj@40mG#yXicH!e~?^h!>L=@F$Fy)t`??nYS&o)JU{bZfE?Vf*i zU;Q`db631u_H+RHMAz2N^PwB3Tgg{M$W>Ev;?sg*k3AFEM=li7Or$PKhCm+o@9TZ0nI~6 zhg8V@?Z!*F(+1MRZ?5_7Ik5(ZgVSC^?(PYMVsU1{C_NCxd1Mia)D2PZs4ET(%Ib@N zcG{=$VV`ud@?L$r-F4IOK!{lR-#!D`JAeF)nB%#rdT(t1zzJ@>Ot~R!fckPeij&na zB(6AHJ7?1RfoEzNk6hvnAcA{^Q3w9mkT!0P@<@d?u9Lcaq8$*II~6LkBYz2=6D)gO zkh?uJ_**&?#*J*+(EmM*Z+d6&R|~vM7O|b2_Ez9Qg8MeXd}wfggo95Dj8%b#HJIpw z6CULTi$N?~(-1rWQv#3w?+HwS80wP3ywcZsUy!>e%4sJRuTRQ7nlj%>hYIql>~~mB z3DNA5v))Y+K7k?#VaVg__%5m3mrme*;&T$^-J|mDiKgF0;7X0)&knh}NbcV+C>wzJ z7Qj441c%Fd+TR_`YSS4hp;rN zmLBj@>Jq8xPofexi}-8zS3IQ);zkh(V`+>E=)n5_I5btnlq^S%@dTce2v^luz|83& zi6V$K6VNoaYIIZRz)<*L`dVXCP*?<$26Lps;=*7n3x+~IB_?2wh7MZfqKVOS0%2H_ zMX{KCB9hrObi&#)g}g4&Iy!1Kt6J@D8ESZ0NSu?q?{|2Gq{y0 z)p8D7V4_vJ^ut(n$PA{zrEVywa46)$(1hFOkn(y*wVup2I2ZElAe1*>`ER)bzd_ix z1*C^m?m8f+3#;4_0-r^w2UqUS1>^4sXa@{FtpcCwbtzgd> zg3Aru#L>?iV{sY-*CepA9|~sMR*rq|*U_-oO9lNP@Lzy&TI8%nLD|1@FG$4x-G)E@ z3-d7=yyxB0NpO}L9$YBr0`SKVfqayG@bpkc-PTsS zd%|B0gEbK3C1?FJfB8)<|KGIb>*Va|2>-k;U$x?Tbwt=8?1-l7#Dc*m{l0rE+Fru$dZ=V*CC?Dc z$YnGoo`Jo5yp*^$_GRMcsa2l@n=bPmw0!3DkZ%Wn*h}C!GZd7$KJMaU3Ih3SI*il5 z{N%wyD__>jH9-PS_Hc^;<~appvZFTo%2&hBk&m&1nxF?NdDtNtGLFMc_WA(1%|nAV831SQOSv_2Ja3)ArWvu_;;D9 zN) zWH39L?b%$v(bnM|EEunFU&wngs5rJ+LF7+(OC4UFk$=<(o`1TeEd8x`7m*(5D#E`99rbxr;~dRt$C~25EN$Wlu#F!B;O{47}15Hm3+YH56n_jH7AOS&E>goObIZ z?u<)dbnDzT@cI_+}&@qc%WUHPFzRASb;EGn>_KS z^o@KucM#?iE^uK*gw983pOV97_>`rz+t&@&C6G62y+fg%vZ(h31U_8BIyv#!fz5uc z$v3Xdd#hK7zY%zq2#5=>y-82$Ci$O6hG)l7dXq)Z`o^6tTQO?z>`48<)pPopg{6Ya zC2}7|fv+^ecN|Mx^>_`$D!geKUO0ZC;?aE@aXz~m4)1DUrEjg69tt)_z-RLhW@0*d zF+NQZzBKs0-Rjj%3Sfu9u|cr+Z&*yZ*;KUP}DN&_m9DQRPNIxuUv#F zVD}F!?dgyW&E67h>zcHGVn088f+QYF_5R2WyZrT`Ay^6z9*1*9@cc;-IC{>#W!rj0 zt@Vpt_5T@Vj>);~3CsU(hzP#-liY$5v`J3-%^SSt%`j+m^zOMWZxi&g$qHuP!6%q)?!ni700eo>3wyvaZ> zl;`1Lo}ctYi2N!6<`t+XPR#>_QkTOPB^b{&pNA^dEd}uqEFCOT2M^;Rcn~BHf}ryd z>2JS=2cdS=(E!2}mZ=EBd9%C}sQNiQo;k@!W%b91}QH)cgL79<82Mo4&`N_l#%Yi4a zRqP>=)48bflXCjkE5t#;TJ?}O2Wnmsw|yXq?-CeZs~XIt2<(wCi?cK=CkK5CmcUl%`Y$R zykP7S#hRYLwAvT!*cLARxkpF77keJHC1syCSSgLY_r->@v@8f~h5M+KR|iL= zYQYBz+(#sb65O+n#y9>WIQkV9RPkcIa!8P{3eunm6%Bb08fbF(ykMI%`1s31L)?48 z&DSumu!O%#}3k1qX|5K))cu$?w9rJ@|CXvcw<215aRq$ zP~=d6acp8fXN)b^74k_!sE#COW-5Ho;&;RtRxXRuf84!z=3~&=k3pdw!@7UR)>bd7 z1%kL9Xxv;YVk`VU?M%p27(PPIjucqg!5l22U@25sYII>)Fe-m07Z$9$k;aa2T@bMP zyZ;wtd6W}3MFz2lqDISu3ib9XIQ3KG+KU{-JkeoB=GY-?eW!o^@BHp!!`3;&YG&WJ zW~T2aKE1WEVL$A4sEpo0$x17IdPwd*CI1hj8SzfOVh){$ciK^>+%6JOpI_R&U*7Pf z{?fF-w@~ocK>TI4p2+^idqqv9%5a=?izq1cGOCY_N0dv@{!4P@=*0BdPG!f$&hl4V z<)dqiUrxW3RkP82zSBEbvZ~j5L zymjt=4Sa+3uYMk9^Y6Lr$7XmWB`3ej8urTYlFXLu2zYWSZCG?cIPu5zxXkYKd&bMQ z?tUBClGla3zX$d`VOb#2zvI^@P<@LGmguLLdOQEkJ;`S#gD<3OMg8)><0|93xM>%Z z+XAmYNcYX?6aLrsU?FAAKGPnV>vro+e%a5pH7nmd%)WZ<>T*Rc;>!1|jg3J`w#M!g zBdL8s`i1r#86JtJl8<(m{r<-Wk8XCu`*vlzdun&4;>@aq7FLE)bfWFaC5pG?QrQ{X ziu+BM<1TJFIUA+S=+3?zzv^t@=yHRr&|-#tu;1tV#?(V!=D3rQ$&IeEC47K0<9-#G z5>Wkr$UmO=@FIBS15yZ71`wQLRT8Ki05f?qz;@b31_V)Md_ttBSq{J>Mf^Z)aK*{p z&!4f+6t<)o!H98y(-6}lcDIdn+Bnw`;ncu9zOCmL@|5Kfe$@AdNMRM-{0w#W;B(Fq z$33^-x}V6kSzy(uSO>y(FlxNc{nR$~ zziEyI1@r3T{I-F#Wy_iJnMvvMzC*4Q$*5OVYpyZ)6?sz*aeHfy(gLa$<%N$><%LiSVVeB*m2bw* znoep~seZU+bB?*PRdm=3P1&Gi$SJV^xLdL{7DKJ_*|y*JJ=#&(x^Wb9_f&S- z+ih<%w0!gm@si4|iN6OL1xm7Il_%C@h`0T~&!}0wux7~=v9f3#2eb0&)Ci53$7Y=A z7hSX3Kwej9b|Z@m8{B@`TrA$!oH-~)HphwHDky1z-}I1faH8n1Ei6_)-&PJHoIiuO zj22kgp-cn#T2d7}#8A=h0t71j%Q@D$<0w1h{BRIeUu7#Y!I;a{JUHD1 zC1lg#PNH#iq)ttklOUXzbk=;K(NiRalh{0n3q(dxLnFL$Ml3Vwm}nBrsz3#EAVD$i ztKuBvyNv-?8mvnT3sywlk5;V{js3?$7GIH^CMYMs*TE>Nf*bB`#~r+^4xmQ=|l zFa^+>y_#y}SY{+Ayg^eK=^FRTcyy=(K^{nLv#U_U9PtcyHeL=b7?3kYNl1T{9CNC4 z*qz&}+ch|gpEgx^c^1`f!nctx%_`_ydTW#B4VW&;a^$`*AWm3r9e81}a{lXG%y^01 zuJuA?3is!l%_Ronr%~jZNIl_EtHAm)SVfN|;5IJilv({8&sK6cVp=ZM=i)poorAAj zUT}Y(>gb&;Ym&aou~YPAy{HneDx)cAKk5oNuTxibs1~TEfxBzI%8nJ&;s4(&>dOkU z&KkH%)k~3Vkv^bU9bp$5vBD>VfZq(X7h9^*naA63&6ceDSEYjZK+eE=d<6MLcje7M z+uE!}iy){}?iJRce7Ae-Lz7FPVcrsQ@dI|1oRj$O@;VNMzJ2a&3c za(59Kl6}1PTiq_=q0vfuvx;wJ+h()dG(=pIV)>FPo8(M^ZI9jNz2aF+&*E8kb}uHK z4qp{8VcgEq+(-uCUDj5&5y z@TlJACF>%!4g_3F&AMXGb)FUvB|#-cgw05ur9@h`P`>@<{P@bj{>rsI0L(-S~o#4_I&vF z&r9bhJ<_<8|GL!X--fxOG!XB+57VDkqzi&u|-1vzCcoVy1y&46Ukuzh~Xn8nxJN>vSU zyGt+pQ8IxQmgJDn)6qf9m=jk^qAqmDRt9_8)jli{(8@{$#8^HizBJ-H@~UA@aZ$ab zW5tbW!00RBU!9S+JH45%WX&_``cCA43DXGYW{4@=QhttR8{C`2F^Lw+Ghq};Y#v7> zA@sy#iQt9>PkktYkz42}pH|-KR8Zu>tmhq|B-NYMS0T1@{u@j)~1JQWPX9* zvS(lcP1FeIL9>#-0M%`uhTD!FL#$w#0qHoL1ZXU%E#ErE^F3inh4W+6v92mTa6o-z zkrZRcy5!ibS5hrEf>2*Np`!|)>~F>X0eYUN}xbn(-(%-Ga-*cqNptU9r;8XGCa#hNX8 zUy9O9E-{mfk&Xp=a<-Yc1R-tF*a^o7*?eN2jF`;_i!BCG2$&)Mt>86vOxDp7LEHmn zj55g8Vp_H4PO=#TzTOnC5kq!Tsw`@=HsVv%B`%PY&!DpT+xuYTk4dB)K3UY}ny$b0 z=mDle*4}T(cFG_0YqRG>^co4Kf!!1(W2q&ICYO-!B*6n$x3l(4%Al_MbRoLSD-~X`b_~Z=MV6yS^D3eMsMX_C1vqh z)p}Bie%a1bL(0&f3DNHkWRS{oDWv(qc^I)6dMVqT`z3}dL^JZl6qSIs(@e>L66+K+ ziP&pPJEclb2;>*Gq!Ls0BrTLGWl&3Hl*tsD#rihiMk`~KzK1%#R~uY%DO&NqGBf3) z3o}&=9@vIy5!p0JZK981^8dsn+hBvLF>(<`F12Y9`9m*dB)gJ>hDgzG?B@J_{P@Ze zlVBgne?NNB<};lvl9857HdoHu0MJ_lij_&?@9wO#9>tKtp+7ZCL5SIsCo8QN**r1< z%Q~A!JnMNLUi33iumgS7Ax z#uPyjdZ&6BHBXPGMIH&G^Vo^7Fm-jZ(89J3`z?$JM2L^;5N%R;3>`)ijoRuI(`t|| zj_dA-Xf_kH^Vl{~Eo+XDH+Sr*c!de-G#!~xmFM^rzlwc8BM61!^ z3|kjIVOd4UG25hCS!vZ)MjlpYak(NaFP&AY8PlpGOv))JsGYFa?v##@B_^X7lrTbT zGQUjKxglH()?-&xy`}&fBuulxi0oVr(r=25A(on{g$xIUjF@Mh=V?g!P>V8>yd<)V zWV9dWuR`dpby39XqB4uxkJLqyKR&x9dS&eW`TjP_7t^j6l^o8Ijz#F<1)*|b+&-v6>KI3 zM9{vdDKF%XH8Q)yVpbvFu2g)e<3vbZ12NW2OoW1`M8q(2@PDP?K+|8mHvk|eR<+p; ziY;DR+WR-lq8X@TqM3Z(;+hltkk`3|nlU_m>H{-n;KL@No)|B)$&^~3(m`i+#9|1r z#3P=fpUQenTvB6_q1f!kRiS#DIgwqM#%AQrUy1r}@5m`%NfuwQJ7k11V(=+%$Fz!o z%yqX~O_;Jy3`#HM?z4S<0CcB>gfoVhrJzcRgF(TM6mXUf&gELyLx6e{P$j~T@X*hH zUd@lSE|ywiv2wTj^{y}s|9WJT6T^7`@&!wek5W!9gE_uWA0Yw zF^?&Yv{OAR`l0vni^)1^3&_X(IhNX85OtQ_Xel_hnIDZ85<=1^T^OXrLPzY;vR@dE zYCgh&4q1dlJ9Joq6n+7M*YRRXnQK2>aMCewN7+_SP1LWvxT`$dF%euxkGTQC$4hOW zKP3&?_>3HP%r-w||4Gh)Quc_wGxS9-6`@-+HW?aXq8^h0wfh8_4U~~uMSocC^fTKq zeM7`NJ}q+O+|%C~|9wbA@Ev}$Vs2a-OoozlVvAc8Bw@)RX6hHww*@izKSqTOg@LlHG*^gZ)GD!8w1CF!ZBtsrroYqc zpKLE=FfwIZiXte98giygD`ymY|4xKiUL?NDBiyBc-0bZ#yMPmIToi3w!|hS-NYz_B zx~obVzl9yNE8$Zln!1Bf#{~g3vDr=}rVT>L?s`hOnP~kxiEA!f(h-(Sp{GQcIUS%p z34Voq)m^nYMJJbPY%|SkvKe6OXWFuq)ZJxwYMQo(AI{Tc7t8X#)e}kauItQY$qaDt z74E&snptfl=V4w<0?G{Hg7mKuy0uDe0~2HBnpZ}=!MH=WuoM>%9sFtn z#!P^415gS9V(EwwIwFsbC^Mm#q!y#64bAL;f15C^!y@=O8_Y!26g*=L?!vPP(%gww zBl2{i{xRfw4Q)X~%h8jcoA4DT*k@N5QwnoHAVjA9G2@{zAX4`8-FP)3Rtj@cLlG5g zqjgw3Kb5Zkb~iH;#i^~n2!>2WBuu4_9z6WN4%@;G@IAA3KsOmb< zj6U9R?N3%z>fED`iLsv4~niZ|gWEUA>kJzdQV%4L= zb}Xx{3>|$(9IyVZ!=RvGCWHA!N4S5H?8ztYdaxfN1BEf0&xvftszSfefjvjRh{VkA zj@eKXxS{6DXRWiBOsz-suIT}GhEAKEHz=h6Kk3DGN@#Yn=KI`#&v^@o#pV_B45u4M z{+ySQSbM(>f?w{=xjfpre<$B=6uMOoCGBL8MT#I5lv1FVBtq?#nw$#rnT1(ewce?Y zL3Y>VRhwzvyF}u4i(vmFb(VUi`YcWA4Ub77?i9I(!kp+ zjQ6#ShkvfdvkI(H`kwza{~D}9aoLCx9!4bv21WRbl~0InicIKNR*`IN^}sJ8t7+45Q2*#XaBP_w zbtmtTt2q|7V+F0I|Ei{rs^8D|#7^cCEc*Nx;uW)V>kJ3&&oWdmI%1vrm4^?{|Ib8l z(A(JQQ~FJlX;n_W8jPoYaI?m)*i6oXlH<*kGEMG7M>~`(wB@ed7Cx=T`=^h8m{y31 zj|PbE*RNbqW5W5gkG400wo`+S?6Szj17fzt>#g0pC_tAT&D*ZDPxqEjJ79KrAwEB< zw`_rnqZ)6v#{SQpl#HRlcs(g&nNP6}qKxXD&K}pC^UFL1u~*41B%8&lTMvJpCe66y z03S*B`R}6h-^ewz5p&}XF}d)kixLgBwU)d`Mq%yBCBrDJ6ufP+8N&&$=E@≪IU)}sl`}gx#f1L0Yt1D!#c6({8yb+%L_0oEZjfo3z|Mw0&qHL@#qJ zg~$P%9AZG*lr73Cu=SEwASF^ahb4KAPu;>ZYhw+ut_|#uG7r`X1J1h8U0QJuVNtdkc!={q?PFwZ!Xq!>F%AgF?yD ziwxcFEdavL{Y=4RT()5@HQv8>dNLvSLQ5*nTd(fJ`D>?&u|2)jLrV5vovHB_MwgiA zd+Vz*DefgT7w3OL>YCu+hbJaP3{NHC{EagSwO*%Z`f#4Jrpa2@na?R7SnBHHqn^u` zOKoH6Lt||<^z+7m2i~t*a}s=w(*hi4aN2gauO4)*^SLc@tz+F-2*6EUr@4-COnJW6 z!Jl8rYyEE2!L9uxZp}_`UyaUI`&=(Sj&&}#OxgMzZi2lv8d)+LtV7$F74`79aYpz%^glIG7cNrO@K8=CmlXTg%8BMGh<;1w| z!ym`d6%bf6{m_D}nec6YZzM4}uqfX9yl56NliX@Sg_9S0v5X6!>Dt_#Y3w=sWs8GYB@Bf{eBh7&62 zn1cQSLM22&`yxHD4sgMhap8a^#ubDzwE`~%V3<@85G|?+rT6$Qv9odk9d30(ie)TGfjcEK zXdnd2oD@3P7P*qI&qrkr7|8er6!lrPy~>iC)8TrC>9r%uTOQd|mCg*3aN>L%8XJrI zM|>~7;ckg1xPts>s(l;iE3a79({OZmwyWdK%&HPtL$rHTOq?bb4x65{z7X@|@&w!|qtcltz7L@j8K9ZF>fk~!4Ah*m%cx34@nXPR*PeEdQbw$c`1sz%E zOpzW4XnoN)2IBSV6E|r;U)g7xC6xow?C_L_8~fX+&68yBK)vwQL;0D@qaJ-#V(a3H zS?Ui1_P>?raa&Xd^00dNF*ggBNo}Xy;~_ahF>c2_7+E_X>G!C}W=l_w-7w^?-~5E~ z4z~-dyrd`&e*A%7gO5_0H9R`Aoir~y`@X8dq+MT|T z&i1av&YyH4A1XC45$QPQ*vgYFB02Gk4&gc_wl3l-Xdw(GL)vB|Gr_&;J27cMB>4v| z+;d)HA2KG!a0Z4}+%850bNM!j0E}4WgLW^BKG=ne_FyxBs-1EhF2?AztVMVRih=EK zp#`r^wr*50CY}qotrsCy)`@`#5x>w#Kfgd1LFVVFO|#@nCL^n&x3=TSB2^E0`l?I{1tRG%^zm`|@s`NBj+mP0o^gODgrN_W zBEpwHpc2!oBRT+wgT*OzS9nNP>jg0A_?+@HaNaG-g%p4Mr3=XESk=NH)LoeKrQECf zRGz7%iHrdb-yDq!c}u7h8dcqeuH-eYj@6H<-0ps@CLrjLP1U`wiF#fVE8geosH2;w z?C5%~fwnp7iC^p!blPB_P2(21YZwg3YAqyR0w^=l?zO3fmna`e(d?aHCqI6gwYz2M z94Ajaulg(|zp+JoC8l7L=W{627v8X|Kh$iUZX@eOFb|zMYj5zfn0!MbxFyom6weGf zziL7B9e`SISZLe&M>q2Nqay1h73`FO9;JBqFsVfCN*WSA)r@iKa~5aq`lOX?Q7@~w z<$9-4S{b`jU)!l{r#2N=`nVSzl}KSzuOTrWixB0;QPQS4X_vc@0ls=Jq15j0p3LFn z#8b0&!`fg!BLg2ZHMwJSaED((1FmEaM!Ug2!riIIW|faoR4o8QdE9>^ZOHMeHp{zE zrrcWUN}0P+Nk66SnT7(N*T=p$D z${DM`B#yO_9^hcs`LB=As&j7Cmmd8?u0?QKH5fkl$i7@su@a)c!}>7f{G4t_44gE5E$`PK@6tg;ND=le^{ezFWBLY|CE`?9gKmmsYZJ20VMTIX0nmJu(m?TPqdd z0svV>^ddT~L|)w4hdnuhy{g0ZSQL&oF!e(@z{-ScAnrOHKd8o0BK(a!{41XIoOIV= zA!^O1T%HjM3XoCm@JJNQ&IGk*!&pmhPHdPP9Tuvir$xf7BnpBS>a0Z``2!w}aFZU3xldq6*0Y_#{JEy`$iEZxC*{9Ul*B=qBZ2WIGa0WGolYzU&wj7UeFG1gXO#E53&!EWm<|OXF zLhJWUShfeQ!-QC8@?~Dz^@GZ%M;a7N#<%~BGc4|(<>70p_vy5rJ&;v_PyMs+q6d~C z_;ba$-nrT$sr3*9YnG7CNpNZt>-f({&x>r%u<-`E^&2{envxEd;*A!Ank3`aBAEI! z)}T$ec^o%!0QaNFx{rx}Ne{fPqK49e&sy+~%FvlKtm#l; zFR#K_Tz0*SIQrkILlJ^r5bqzj4kg}NCOmkVj(cqpO;+H_Mc6M-&`BUJc~Q_w$DSHT z99xZuV~3YX!gMQ9MMCre(>M7kXucY;R*fhJ5t5bUu9pad2)pSSx<{3D$YjBY&@u_C zM1n3h*`L)Sd>1#C@X%Z6=%38McT%t0{LtV2`u6ocL_Qs}d}BZQ8xu0ugxse>QFD+C z5iPU>o+~=cmZLU_knJyl+(wHg7sFl!gi4`I9*h7%T{R&$s=sqGLBF`=-T`on1YBVw z>()5-nii9A9p~bPyO|nqnJk&wxC^E!JEO%8sse|F)?dZAhkfV>ODmNJMH!$p5;D6l zFCm)2YJn0B{NMF%98Xx%Ybf<81P4M>MJ`n;K&8WHqTFhD(FGD@CeJk}RSn+9?==9c zXB9ck%xW!iq=@09MOLa@fBxWpL|3Vj6wIisJ%5i?-iZB03os-Kl<8ckwy{f%sfXYc zq7a1+GZl&3ptrpuI&ozYGoW3@y_@iqt8wd8s~aIek1njfoz-dr&O)&7G85<3xYImb z*?lm1e%gzkRO}{3!rF!AzQrr!&=0Wbt-3=)gqvh!* zXZq;)&qC{NDRxwfzoEtcuYuDe#BwvNQ(hu^7ID`>+#xOD)FM`*7usA;Pu_qZFj>v& za5tjs%SLg-(tS5962+5x=*kkJu;8 zsRaQ|xHEL|Ew#Ynse}B*)=I3WMYbCoaW-z&?{loL32`@h*3+iIK_TwWJn;fsW|ZRO z5?S=Mr2b{0*FUK^U$W0dA$|oEHyDUH&9glV!8Jk}8`W5FiF#g*t}kH8O&CaJx)x&b z<%(ZRaO;#feOKsNY1S#>%62x;qC=lkA#&Ad?Okua1ns9n8~_nhgE5@-$+7rr5rD`*hCqm>(&}-?am_>M`5b;L@YA1*)m#nB4q6`G!!FhB(#8G|0 zv5trSZV~;h8kGZS-Rc9&5T2re)M7QdK#RgDQR!Ng0)myT!u$`|a8-xPr=!RSs1pQW zn~+1njPE3fMjloU+rTBt-WqUonYdXkt{x}rXX9q+xCc4wuM6?ZcTBY=q4f`@)^p=5 zrY!J0h`*=C!v4JfjlaVX1*p98b7)02HN$Seb6S zNKrc=`C}kfLyy6PROIFXN;NR1!+*oH;!e7R8h{KDa$d~1+ko2o*=6E%>?GS=#Z%`N zd;F|-39tY&AawOK1g=Gz=s=Q)%ZB29;2&uO68aZ0SEabJ2%ww|cv-~Ei-=1S+=b%x zy%21d3A^Jy<{X`4qR2W`xSMS8?NnTn_JCCDXXgP-{)>C5FEqIx zJ1@lFPs2_a3V)sQ#t&nkNA5l}j&ngRCl!7`bVzG|pyNXJvwPWT`ft$?ZGwia5qg>+jDW+owzT0?4}6(PbS@NNikar&?gip!rB-Sl=gtM}kP7z=E#uz~*-X0MV4q=5g{#iSPlL76ah%gLoZD*L zEznvoMJzqX9$xtF5~OgU*#9DkUx&wC)TQbn0WMYcuhmvsomjYt7*jmxkY2sR!!`*4Gs2GT7>XbBW6}v zoG{qkAaEE&4?=-`Y}8;8>W2x`RS=MZz`n)DKmP(B+=q6^k!OI&Tp>yo939Z6ImkvS zgnE59MlUdY$41e4mc=!FmEiBIXVfq5=G zx%)ibZZVhl6q;Uj=e|!&CEYcZj?Bp$w7DG}H-E~xipCHE)u7v)2h;yJYL^OFZjj)r zpKi}DuzDZ?+_VrV8#OiSISJ--xiHsk|S!+_q+%jcbTP@t_vPP=UtQBJ%F8M z|L2bj+J1?3r&@bf>&HIC=}>`fsu`PVY`GNsQ8AOC#hh1j>r=4QV8Hw(@C}G(5(2k- zp+{S zci4XEi~FuvPIQ1nHogdn%{akUYir4_9KYORwFKL-xXigl$C~N)#!y(fXyoB7tl9*8 z$=TP-#&oH07gbnAF@AI$$D-f@%kgt+toftvybc>0ysp=>!d*`!$q*K%t|tqB8-iO$ zevh6+FpgU&q2CEP`laaox6Zd+8c)@=s0UE zID&-)Co8nbEtc(OGm5T4{;|wy%@b|yUBYHq@A7+nH{^QaB^~yZ+TVT53fbJscEE${Rsw<*@(b5XAldmF@ctFa?mXT+#u1Oz_&0YX@W*h!#?U6FyY zio!J?BQrj3Erz&QLa&4_o3v3IZVdXR0+}FET;jqOz`LG7Sxe$wl0l`avM%Ycs0*kC z3&IPm{sUrm@t^=30u@0tD}I7X+PmmD1rOstf-aL-A2MNAiLHN>{9PeB(ul(Ks40&( zV+O{ZNI#%&k2}qeqsP|I{97S3j9V|$ZKlUP`~Je%-Nl+fe0eWsPJLHS2YPwd!Zm~i zTYrTVyFA*zN+I@JeG$_?j=58f%Nu3?a__ z$OAWQqS$9+KL%9)Zy7tkxaZ?6;%|#z_Mg>V(yQ-Ba6hCIitq<+fHpb^?k5PoC=ov_ z4Z34m{f2GbQ;a(v1{}Pjn;r-B5bT%D#kVZtTw%hoUxNRMG(h~UXIlM^0oG@pVh{h) zlNW*eeDvQx-tamsO9MU&Vf?YaH#Ip-e81`b(~C*cCEcm(UMJM+&mf-7w%3rmGrMB- z@!NI^W7BT9H~juGuckFTA?Uw5W;^bkb#6_z?|oUZF4{wZKz8q|O18I4XhOub<*B-s z)6TT8%=?s&&xe+gFt$J@6zb2)}N z0Hnp#zvkxTJyu7{9ITG7G{(|gqfgOxC^~9dLuK#rAKUNRFmA+@;A`M6I4g$~!RYi~ z=p+|b`nWnS;*IsWq>8USFl6`o^ZZNYEgwu!Lp{zVQnjen;Lq@T!_SmGx%hK;M&EeN zN{{5}^BGTyUj6>##V1{%B8PHeS-*JN~4E>>)J7V=j^Zn7WvHDnD-fXjf9k zzszP{BB8@htkfK7$w_XA0@ujt$#Qh8S_cPtDh&dx^U-q&;h=cPb~RqJY$RC33Kcl_ z0=3L4Xr7Lpw*3ZiJXRpfvjhzBBCtLJdO!ssYmgr z!nekn^uT9P)Q!dbZA1wo7}?Y-`SVgkgeNfpKmvS!Egz%9i41mdN=6&9>v43*(-U*0VQ8sWZ#|jiO~` zdUcY%@%UHK9liy6u`x5hVFPhX93k&mYKDapP0I;>L)%h(Hjh{p7%Z<@voxwFr8Upz z+Au=Ah8u*m6ge1^|MsoTno_PSKn@f#fZIzF7wY|Y(*s9&K@;uu=Z{~2dxq4ha_d(; z61l0kOJH7o2JEduwX$9+Z^}O%`>M(POZAXVuyPqNEUtM3fz_;ePS3Lmrac=XxYqMv zz>0bhj^`W(U;uS+(NOI_t5O`}$N@pM>u0S?E#n9x7mS|rG%A1T@yxAX+0>HT6(0U? zo^Q+&s0aOcG7RFb)No7Tijo~&2N*mK6+EYx^(r?ExKJ*)(B**>N5bjd3fh9m%l!lP z`mhVdV5(pz@@rEoDmHYDuJbsMa^-)Ea&n(%)Zv{exUa$(u#1<@Qu6URw3pP5QL^)= zaRs?t1S1Vd(b2p%8~ZmAbSjjxmMzC8PQW*j8UTTG*gD%ntEwhYZS=JftXaZNp28A0 z-9b326KuV&!p|K+t?FScTkV*JSM&@7>gep^UG6L!&4en|4F4>3JX#WB}HU@Y}>d-~`C#TeHE_Jb+^^9iiQ zRr354XNl(po3eR?_W}Od{7IPO8&A}k60!BZM4{d7`5|Xdy)vU(P8s$@c`D84{zkOh z7)(fi`UrM&I^J=bTSa7>E3#VJZ0b{qEK)G0sKsU5;Nq}<-3Q#>az)D>_J~dhtqq@o z(U$*eO7`pnih~O4O1BG=DD1G!JepL*7 z(Wf?2|L8H_k`J|?rrRjUfJz^2Z_tJXG}+;u!Xpo=+)|njo>E14EH{R3ZoM4RQ(fuN z`=K_o`f}mHiAds-uEHH~u{&2B9u$v%oaXAoT{)dQ^TF9njd49xdP%>9U^?(FMx)X zk>^IMH}{AM2YGPvBv?8AJ0G(hG|Z3A0wMi!o2`qm+AyIG>f!;*k4>oH>T9fuXLbWyH%F1a*?LTjr4WC63B?^Rs^5#? zStn(lc7^Frm*2~w_DP3*>L&glMR(%Y#JT_HI&0?q_fKtMJT z5NQHPQE5*=YN@3*lLf*eqN0tKwgFL5QE{n7i)|7DMnxMCTeP+(AX>E46D_Ad>fvzS zdH;d%Ntk)&nfv+P*FB#2x+makF=@s4IGr@o6L)1!cX(A``{qIQXXju_v~^v38MVj# zS;t{X8I~?h+0Oe~Nla`XZ~i|lX!^NeH2BWft$BZS3)cRg6_s)R^w#;F&%f;>`DRDA z{*O`CctvmvBT=EvVA?pfa}W7Dg=g)-E`>NGuXXD zlG@AJMKRO2zBCmrOL`wr@&}l?vmMM=69j5zg6P60wu+EiQCpUn!wXzHGDgbrk}ce0 zhn>-Q&p0$&4QAH;&oln{yz@~aQN|dfY2Txo@qr7|v{R1csQ&30xU&q@+qI!t2Z@{d zEVOpPI%rJgx)u%uuNVkW>O4I+fZ$;h?DdbCZWh+7nQ3wvcyvJPO<3xcQ*xN0RX}rSDVY3%|BzV-120ReEJ)K$^e?=Dk7Awu^ zilmL14qYylymkupmikllDhkL9HIm|}!CN|1jC$3^Fz6qP(!$i!EmJyDE%8kHwTk~F z&x6Etr4F78D6)suKuw$ano5-vWj%h&Af*r@BGRU``sj26F;m=Fcv)8}X-t#o@>98( zS-DMx9l%hop76v#)c@3zYK^)rU{LvElCn>ixk#+7LY8N1!?&8rdvTOagSsM<2E*KP z=d|4?5RAYqkqB1(t(7~p*?JxII(*>K*I6B-heAj>PHe3>^MDY%8zCWueR+m8+ax5J zGQLDW&I~0N;DU>kqypom?D?4dW5f!(4&ze`Kyp5$D{yKzfsF-K(U~UV@&9Q1sM?Y$ z%JxuQrj2ynrRnp~v*iNv=0Wlv4!P?f|FMCD$XOUzcGj~oOMkgyg-%Y@W#M#L1BCDG z#9x(tc_v*^p9?~5dSO4e$xi%vkTkYS{mUb5k!dLK_wenni9bA*U@_t~QyL){23d%7Nv6AA!c4jd%N5i}c0GrbvVT(dJGOm}SSasM=qC|2z|i z@uVb{hLEGYv@P41ScR8W5k~MEk_PF4Gn$dqB&i*ppFw>gywCwe%mJPTV68)A#%b4x z2|l1Y(WqGo5)dh2B)i9TjInzXkl56#WZ%QEW?kVSnL{1x)U3=tYs*y!rm6!Z_rGwW zD;IImf2vk-+P$TO6$9y02ghzpHy(|HDXkHSbDtws>P=My*C(b}DDutJ;AY3+T$~PL zXrJRL*fkaQv+HJi^b2wS*$`|xj;_$4%~E~)?qu|9y{<$`IxHn6zHZ!P&=soa@*NlC zRC3uem0Tl=jCo7eYQjJ?UNv2;o+A0YkDPRy=PDbnj| zBsvAI>D7Sl@{%@lkW?cTzEBctmBjU9-zL@H&de5-7LV{(>Xx@jn>tqnY7gn;D$+lk zjSH|Ev9A)u2L5-fSc#Ny#4%QtGO-Tl#cUwfG>DHoku2 ziz%%Ye=JFL82;9ujB*^%`%1+Pli)C|pSe z0slE5Ngt6^l^Ac=1ll!m9R$Ne&FIkMf{u=+BN4@*_O*?)1=QB!`|;f1y19@>vA*e{ z_8maCEmgOS`?lzjXP%f#PdM5~1J3#(_E**ME_1 zwfXRcBKKduaA*QRZBBOacB}S`o#-rk9#8eX`A~yl9uLFspS#CTqN)?c_!q}ePd(s~ z9R=d>KK9=d?S%D%@Guo~EeTm5>H4Jh+qidL&1nOEKM-!iN7zq*g9PYr!rC%zM3rt0 zNLV9nvB_HAhN*oW1Wy$pTEvN0Kui@DUeiQTeQo`Oc-68xUax^V2-n7Z{Ik{X_(#*~ zbpGc&3mh}*Lehb$w3Hp5@1M~z_8j@Ql(eb7GY`IzC-2TQ%t``CMX5SOuGeMZBn)1K3No3+S zC%HuVS>6iW?iOp?G)x@V>8nXwsEw*Ux}qsU)&NC5(DXq=^wu@x@-R{j*d&W++G31e zC(|kC%C=P<$UD)wulYdL2PLP{**n!< z8%=&^psC(K)UlgB&VKVT)(N+cY#cvR)qH z`93%n$i1mOPEz71D}tJ~3COj|E!ZAGQKc&ab;Y<_MfJ}{Uuh|HwJ=rfsnV0|1jiK>=FkmPnQ z>aIld;G|~lKM89J)yvUJk_@Nzw-Y-dq>`5U^UBcTPbL16`_FEWHaQ8k!q9AR+jtuq z=^zAxYN>&A_)9!QJ)4-odDE~#5RjHs_38>glAUnI!%gbh{NHIi)#%|x6O z{!?yE8t^RtmE?U`V_NlnxP%xZcqtQW!?9o1=JFIv+@+kC>Iw0%`lm6RcL(}Eg6V(B z?fwF_K#&xeioXqfdo_lA2R6uRqJRlDm4VH-;^ERIQmZJ^U|M!7^`As0aKul4QWY+XBEdt>)WV%ekbi2j@?p1DuF544rNxMuEwd` zve>(AP`8`f)%}Q6KYzB|?(xZ_G0PTRD~~I;w_gw2qQ&dW&dXi#zuh@JP_%pn_WpB{ zws|1`cqQ#{;H#;Dtke0;e&o=C^YYVW2Y>!x=h3r8!-WTTtqZFN&grTB!!6(}J8aXr zO49GMd0=qy>9^Hxv&{#5!&X(+9}{UuJD{|e&OY|j*3bOknO%IH{O;JX*vf_lhufUK zE%DRGU2(|)cHx?a+K3k`TkoIs*DOwME32S=bu*n9H&NcnHm;PaW2XOXW_^{G*`c`6 z`SP~+rj0Mhrq6AleX+JjOS>BwY7a&}Y;%8&pKZ+(?MxaUUrAu@_uS@~Pjj=Tb%@|S zi_US~i_d(@?%NM0-+FK!-WUIRoN#u_Ie7Ql_a-{;HU|~$U;X!#Ur_6jg&p=~mEsEd z-tny;;8zuueVTSCAa0(H?+P0RE3zE`3fAKQ3T(m+`_!g>eT$tc1gv0bS9FTkZs8>& zyP0V1y$oksL zC&NF_zfKE2-4MbE9-b__8-1oLRqvDR#NnwvZB8@APYYR~itLG&Ftx$Y0wyXYcp7;h z2oU)(@cOWgG0_$m4;Y6M*(!+Po7wP4+x3akay#g9{{$s$x&KLRknh(=%UU^ERR2_8 zV}2(+_S(Yb*2tTaSgU-NRrW*J_^~;DlX_yEPztK_FD*fI) zj~RA~nmZfyY*5zLoE0M)3%y~LZH$^(aBKa_-03zD`P<@()R5Ee!yBV-3G=dK&zuu2 zAvYO^H%9eS%iO|x59V!(?lSZ?`;W}l-j93qQkE5>Hawt<{~CKqjk`QIlNmN+TX8o; z4JM~YA0BXHgw41oGo>LeSKZ+NJ|b9>h_BR7q(}aws!fNC%AV%s;1k)WY4w*HKDy)cIdqaDXy^k* z`8D7{BKye74@k_!I+4}8OuN2G@Q%|+m_t=HWYMPqd#qr&6Sxz6Ts5PQ(=gt+`2LiN zxA`R7b8dNi@qcO%Sy1Ou9h7T*arYGtf5 z<)E9R#UGGdch2F41xm4K19R_Eb|X8PI=)U&nXU0G1Xf3@l3nMNbYu~6f6UQ^P;*+4 zt1js`zLpJXI*r|4uc)&lo!y)3i^ zWyrB_bQmVD07Fg$B`7$1=LlD)Qe0RkdsyE}*7+76Y$h+uKKq z%SYOGPs+IMhG8(piGmGsV!P7%ssMFC>P?Jm`$sqX=rb{wT7na%xgSni%{E2w*DusOi=8M+y3~q{S-6+wM}D#b>-WkDyzOJ0 z@a0;hT|S2AUJ_WNc8aLm!eidGh~1)c6|I?9EeAglj|OF+u6b=5{-8t`)R&KS&f|ID zN^BfGz^r|u7v99$^_=GH;=j@iDENK+1@s7uYg zVT6LK9=jI5!t>4uh_MZ#_L?CJZ))yA>%cTeotq+@%|0HHjy2rR&ba5iM~dIseP!Fc zEVT!R<$5o*Ztf}*oTjdCA0tj0#_+=Hiz&jy9L6`aa=~n&F1WpB)9yp-g}b}ky(@cK zva}Y?m!^~;iMc7KO-HF2ha8m=M; zJKe}pP8?IFN5T6>Li7?$$9_XsVyaeBI89%rLKtmMoiA0*e5T28#EMVrR{1RZ zS7-`Y`!S^Uel8)$O$Ia9|40aq2bNoq?9=XRXgeg86|MRGYThNBF1$9dHFrz~p0{C- zq3F-@>SVm*i^qd@e%*5I?APv}LsKEOY(|9%f9T7sQUAC^TCQ+EsC9<7*daZU7oIof$|YBN^8H!60G^TnR7ufnlYj{koZNFpsVoQ*=XQ=v0fw4IFDW_dELBO zpDfro5aT;k*|NN)n#=osZX~YS+$!%pA()dk<^Ds&Ebbui<^~TY9#Syg!Oo9vK^uat z0DbqNVeaZ4$~#4h(?Rt$yz@p%#QQY`)kXquP`*PV&tX<{SOlLybEDdEm)(tFTi$Qg zg_6z9rH(V4>&E8rg6G6iIDv1(kseg1Fk5%~@snOyB@@IJxUPpc^#v)KM|as&vKn08 z^Hh@We)vIv!hOfg@V@x9F49?m_3_{HyaWX$2nR&U1bFQWNmKZek+Q~a!9w(!xN1>T zMQXCw9TVxfkD0!^PvbEvX^It2Q{U}KW<7@Am$ePEZz$J2ehyKx>qX4>?V8W;VX4sN zrS1BaV|BO({ycZFs4dH_TJzwQg zk1*^ilGg0~1i+1}TniC)pUQ=5cUi!(>^PParcat!g{GQ1l-pngpNtAS5e!K2j#2w+ zEs&tmEgTWrQSUk6>A0Dp1)BdWa4EtX$pC(Xns4O42}gxjTi6nm<23)H4)3Nn&$DeP zXd_51D}^flti=aH-K;=4(ZY#Cg`LOP2?sx^QbP-OzV}$CH|B*HYMz|xwiS``J`_dioZDgrjY z+e44lz;Nnnn3!TBcBK=lk?=(TKaZnow-Odi5g^{3xZXFx98ixB0cF7%YM&>Bhy_&g zM7Fj|-DhCb*U3l^E%Wo1rUvy>k|uZ_=*3yWIxXyMd``7`(xyI-0avs6!JWwx7{&Zn zHe)zciI47*g;*`!+~M$QOI)WVY{8PZgC2b~_ksx@Hi!?LV}5VN#~L2~KqbU?4hI-C zQ3gx6^+~{#CUM!JiH0=MD$R2DaALoETxZ7jpB&^r#4yvM@CLOMr;h8#-+O@%ub+jC z$sq&y5PR~h_b<^6c%Z0Eddp{ZYHtSu1(=0aG`OlMSdZsO&|ni9 z90SmQl|?Q+ahpe=ezmW`EUZWCUCVj%FwY6|t%z_8@R&5W{TF?9Ugv>B1y2xR6{1n1 zUP^!uB3s8$-+45Qfe&?H&YmSSI~$rD^~G|{4aoLboA3N*nC`kgvnZ#6z3d%>pb==0 zaXkm7R)izql*4Tl!Rnr_4JLXA$|zCsCvhGoGmzrI6(OG0sOu1nJ85Ffndq}7h8-5T zBTN%QzY4!UpyFC#tn9>ZK)^PnGmXM&Qwzv=kcj%!k$rJhDttG$1n?|G7+(VZR++D0 z9R(rpdbr`-Kq(D}G8;_H8WX7+lP%Otci3wY03lS{VCIiC`CnCdjiMYS>x~rgi#LaK zSwarU{Ce_&-J2s8tU=JpFr_*~FB|TZ{gHh<+n66VCeAV7DRJfqt6I!_GPD4282Cr) zdQNYXrH^ggx!~dz4{(UXK?~@*_T^%1xxxU#AY3a%#fR0|jSp$QbLw!SZtX+7r~>uT z1JwCL6cr$DL;S^)lundaXNoR$VWpTz**Ic?3BxZ#7F9SEOr8oQHl&A3P_~uso)8=? zMWHr7-vxCd*jScio)T1si7UhF)ZH{q>3Hf9BPP0HWn+w|lMUI`=gpSH`r)84i`ON5 zkb(Bys*FT?vS5FTU{&*nyyWnHe0;|+lv)yHvIO8PQrqq5Iee_jR!kd?yef#hl*GQv!q zAh%n2*x%Kx z4weAE0SMG4b8`ah!rvcW%kWpJbLyXASMLGm!=U=3`?blZwpB_7@je2?Tcs8U06v3q z##PI1DPHVEpx$Ht4eH`W#G?~+(*yhpgjH|$)|%fBZwd%k_$#q?8B49{HxDJs>cajC z04Jz-)ZK4$7A4$ErADls)jz zlG(3b3C+SK)a@$Zwtxy4atJ%hG?6{R@nqsp-bzw{UbYNQqTZ<&z$zRaHsW66JV3;C z9>*F~yY<}*^Du>kchfHK|C5zH*EKma?H(U;u#!V+@k|UfIt8U z7R{V%>6F_eZq9$XyjgCt3oAECKQ|F=z=$?cgvhK4b4iP#H>CZZlU_}{aT9>8~)zNxu$ zx+C@C*tuYnCCH$KTTJKUZ!2Os{TpZVa*GLxwC-)fl~tkHuWv zPJ*pxIK06U-kT%3_bl`(KF(@c)pt^&)g%x#;ZHD?cUv^Ze;S6PonDdi>Sz#uYsXmR zpgLkm7Fq8e*IyE0#iwTc5hWasnwLdQY9xymjuTliR3^4-gmw78j-cVTW8-Hm(Otv8 zoUw4agwY+lBp@M!I~+A-iJMy)KgNJeof1`YRD+DPc{VisZm`WSTNx5shmTVFi6@Jq zmW>2XYR``@!!wzskQfM>rQkn6yzOQPBdTK5zsk=;Q%LZGf`A3*Cm#a_!;h__RZAnWuX(trUFW~SvU(bA(UZo@cR0= zDJVmKthzm@6Ci725YUD|1HrW~F|PnSg~dgA+Q<#?s%g%6ik`#M*I zcj~XgqxYpi#H>jCb(+^vTk;N>#P%>lF~42#eC#0pk1-&2 z>Z(M_jGx(^cy5=>IX^^AUeTeJIPuoDVV~EQDDB`^A(;W%yHO@}lt2?Qh`)Dfvir6g zTfHUrl9r(2@IkFxb&5-=Zq%6b`oZxZHyI%&CX#8Wa`BjM6gm{Yo z>9IVIdUgKbyNG&P3wUFd{CeP7y;(4c3Mb9HYV2gd+*M=o=pS)@QR#(K?;yy$V$4O~ z0J{ME^9wboL#n}F)^=$RcBXzE4nD2`t3#dM-j z7vk!R`c9$#R>Wf#@D4ZgsfdpjyGqSmchoZmFSMGty@(Ig!Ck_+ns8K8kYEV+iABYl zg!P#pmi`IrUcwCRcbe!IECB-IP*w%PuQ2n%CA5T1p?PKN~drbNl z>T5H-_Zp@YPDk%mF)V*UHDN$#bJ=_8 z&nnx6ej;UyvM%`0m-LD|g^x_=aO2ub2`SF$a*KxIv4^NqP9sDYm&y6~i8DvT+fNq0 zH(?E#I#u$E^rlNbBQGR3D`tn^s&hw!k}joRx6JnDU&=gL@$Z_z)~Kpy}xv1=R=pEu`l~Ghw8r= zCS7=9EW2U-!)?R4{IBj`)BYR(+3$hT#lOY7{z(iwD#`ky;@;&WXMYIE_~*(?mxCEc zPG$d}`g+-Rmya&gp4!pO`RQEe1%+l0?f2mqlc(~o2xdZ$Zi11K&aCuDSnx63tmV2J zloh-r{A(*SX>rQGJ${ePO!uh{wtRd1;_ILheige$k zz5T2+%i$PiL|7f=?dyf({`_jsoTu$;v;+9IwT>!cTY`43j2UJ#_%o!Fi^(jX8;;?2 zSm;r~N?Pd9bi`Vr;nrxI@+l*U3Z%43-YS=g9c!&ShuP_cH6us11l7zQ&5qwe@+x_} z=nn>nX(^0|MXE**Npr3=r;QoP-C*LhwKK8u(J(81d9jR@*5>pF6W#BfY+uth*LyHZ zi-jCLFNpE&Vb^f6tb{kI0S7#D`;bPTz4Zzmx2_MsA~yy_V`Qbu+!9%5M=4PaIDBJN zF;1rU?XUy1fSy4=dfE8_{QkYa2dc?FRzY%e=5N$8M%+>9jNYeRP)>zl2ihb^aLS3y zREI;&h{P#KJ}S_x54^uP%z~*{t~9%tt!2szJbIgh(tjxj3ks&Sok)E185WZD=%d5O z<8H4r98h+T@VBNfdl31VUrI*YyNVRg_$4LdfuEJ`i`E@L_I^ZiPuv&%F2C?C`25J0 zw_=hPIq~+h_$+zaN1V5B4O;iyso#@%LzzV!Th{X32PXeSYdVF)}#SoWdb~{E@69Hz98S;mr6DGPmSSUP;I1&Tnt|Hq3S3@~M#n zOao*DjDR+3gh2=; zydasm{QEJPKo_?mqzz`3_R0hn+op&NkL77V-sEHeMP?UX$w-cmBJrPmAtRjk2#+!D<%&{99!9l59w0KMokXH{* zv$O^R|1nOVxI_bM?F3QbF`NIQ+4Zu}8a!*FY{-ytG>|qVbesw=t1Rq&efogCXjAN> z;^x$A@pL^U;}#bj)UZ#-C8UIH7`m{*BMj=IXW9r<$p zZ^f$87eL+NViB=OOkl~C#P}r@b@?~}x)16-|EOhcP|0X}1_-_lMuI=?3kuYs^Di+{ zb~`ODe@U>Z+TAp)r{}mZrv=U~1G^Tj&$P|jnCoNigOZWpjMpuQX@=X6OS!vuT3_=2 zCy>vP;*{)|-hZ3jPs1Z&vc=ZJ)yrz$(^U6ZdpEsWU(UN}8eN@XZYixB-f_{?7|Hmz z>y~T__oh%pr|uO}0CzLwpb;j0TlY5Mf7wMz=JzY(Rp zGBpVylS{cB$mtnCqpQkT?Hw>SZ&7o0@P5eIRud7F=<}s&@WV@_kextlro7BA%is|d zWNyu~TJlfD-3y5srwmqDStt8QZ>=7!TWx+k+>>>2**R_HM1SrrDLGNCXfb1x#eQ2{ z7afBzY8c6a(~ct2MuIV@*N-%H8Cqd@CY_9elH zZ3L=CSPE;|GH+(!-lGdA%CqPFgiV~L6LNE_7eU6iNJn?s*f26L<9$7NZnYAV8;tl> zGhHL$3DK=t3)7qpT{>L7{_y&$jGPt((&Vel&4D5(~? ze%SX@zFFCnFd$-Vsh{Y+XClSTE0`Y|HU2INa@R6)^j$-8;L&{d|U}OMtHjUVGL=E9ZWD(7$u9ms4u-B)nSrc0&W2xuRpwjN9$gdROGB$;{aarCd z3_zH$^ns5bK#c1{nVXA`xQ(g^Quqvct4$O56eRczsjxJVij*rMc-mnicF5yVf#rfL zM07cn%t*bC7nmSIJl4RnX8_MOVUaHTDb_VuxA#aLeetq|I^-aEbebqKIcjONH%5me znd>_q`4|Q$k&AE%q%X8ealncNn7jro`g|F>g2-*`_9@nvF{0M34KiNg$S8z+?XpFx z-qQ5u>N!-kXU;!KM6 zJozw|FCT?e{$Fjf{pR)k9O-@ebW7n?_0H2+ydhACy0%_b^9X-qbIPLE6x(}vXNvZCh!A+<;FH8@?zi?a$FAoaY~7qa-2pc9aA=!*HNws9 z7ODYm!PAxdy0Hu?Z|+@QEh_AScvh4XTF&c8<(Bzd7w@+0o%OC86N3G$oo3-q>UlYa zCwB`f7+(50-dOkOjpv_a5WIW>FB}qi)GGL;fVaMn3!_3${YaSz1Hs(uy9FkcTRbh; z)6Ef2V=@%KtXsHTHtoJ1^{G_wVPr9&f|q6HXfJZ|yE(AJuWtkAIR%L~b7ZKjKLShv zGO<}|WR$9fSdw(A%Hh%pG3MYxy@~k>V%9^{mpEpMf_*B2)v&JZMspJqfdcB&3lnaDj>a=(fcDnC(&V^~z=N+V^ok0Mt64#Si^5V;K6pRXdxO$621 zXFXvqHykcmp)O8c@VhFCR!*h9W?=N9)yQarX_$Rygp!wVG$P8Vhbe6kJ>YpC1iOwJ zvCjbuM44T;xj2Ls)6K0=aAQQ^1dF&9b*oiy3q^cP#+pp#S9N=pnB9F?!uQeJB>;aP zO>%MlU(X33hTzFjtXy%PY~~SLbu9?!_rs`C;Z~w?Qz~|jP7CYJmn%Wu(OF^2%IeE) zyrU`nxA5h{Zh>9_$*Y98WgE{5c^McRm;wm^izzVzpErdCy(0MT~>#IAoa zFGGJzAs4BFHEa5%_yRHe)_))qP<;9W(8M?nMM^MY5dr7_MLX*F@TD_qFUl#HreBJV z%}5rcBVa5dXxH;%ySdix@U4hIEe7`}_3V}`Z{{3n9 zy@((cWx)s+gTEvQ_t!BV^P(nQFMR)m`+d}Br#WRrBxpy3%Xcs+QCMph>QPI@6}~4W zwh-YpU`YfsJHPv?O~kjDmoOsM4*cJYh+AUji!TZ?P`5UKhjDg`4(?0f|6`)d)%_f) zf(M&9M>h!ORB&TAH>JCA3FlSX&2jAF9MubYp1K=w?#JZ3StHwFqRz-^FEL`#NZl!i zcA3aKjT}7o((j=6%iDIsg*%P(MXb!Ea{tMZ_S``~b&q}-{_&-e8Tx_iOVfmWoY8M` zrA9JesVEc1(z-*N|4-$5!DxSJbQ{EhN99K^yI|Mw^<_CTL9t7K%I+{hZJPTl75hg= z#)6ae{xt8$N`3&~KKPR5w;f|Gu%Sz2Ud<*%MC-zNEvsHoQFk@UdJD6Td-4Ax!7pen#fCX>|15OC#UYg zoxs?E%O*hzE_F8}#U|aW+Gz zFEsO70KV4DtMA72Kq2-X+o0guR|LfZy9=kQ#Q=}`uYd1V2)_fo&W?Hwbqlj*gmy%D zf7)lp?Dd^OI5X|lg7}Oogw|=yU=>pNwVxvVJ1AEL@W^KOmKUlnxmz*H_xV%cllQ2r zlFRIMUn=MP=+8a64AEy{TBE!c10tmgAQmE8R1{?y;hKr2alTXp zx%uebVi9kk8@mO(_ED8Lukoy+KO2D90wA_bvDl8fued7Mi(2&EtXKtK>My7Qyj(Oq zQ@0>%T38s)oku@?VIKW!H?K=&-YLRL|3YcEAfS-HSIqMh@iV4rmg$w9{`^{Z4=U=m z@97>q)dLa>tRk0x{>a{o@B(^Ri$;!mdLz<(Zqe~wqo_rO`bcIm>L6ozwI*|$ksg^l z?tu1>ifJt-k|B_$hFhbwEUg@?%YLELd;Gh*#Kj_yDPQZ z4cASRqnx&Sn%cqGXC#l}nDp!~9XjJ{a-bjGo#rMOk4i*2WURJ4s zv!Ft*o?c0EoKe@}NZ{Qugog$DdIu|1w`!ZON5A4jF=4j1aZ6l{+fVJeugQta!2v2aDzW@vO zR&cYrN7V|@iUhWpxeNZH7?k%b#rwWu<%{XYiPJolqCxQ6iodWIqcK$(_E;9HcApZq zD0pK9A3PnqJ>MkEK!w%Q+?@)pR)J_0$t3e9@1oo+?1_Nxo7IS>=D7vHnJs|V>{dTa zEA21~=0()QfQQ@E9q~W-WJFLo7v-EuN;GqX_14V_&fsm@z5sqBK-`6NB}{*qtHP8I zh3}7Z->Jx$s$6NJSHgL3RLqwq=DeKcbkN^8n9t>pzjNeX3}C#JGnYr@51-%|6Do%E zMM1}FX!~{M6R7=xioPs&onK_u9%Q&6Kh|O6?Ya*q>^S%@6I<534ohvsBatD1Cco%k zL^Er0@W`L<-CjjwEkR6AGlM9~_fUWW#15bg9P0YEk0O|6j2 zc}Bif3sc)1)S);^n~FLM85ZSkS-R!Vd(K}r%PC4(S=&oSh1|8l_+ZQQX#alB8yUYC zg@Q!fSaVkvf&a#b*~}LqBxS01(qU&3Do9h`&QK zvGg?$yyVmTnb;~_M#?)6N_-u z6(}p*qg_OOv3!Cf3FFUWC#l!CB`Dj33Sy@@Cehl2j;l(4^Vww19aNYLd|(E=>#LJ0 zv%wa?TWV(86xuUV?&F^}tODQmCcFtrZr5a+(&ChDOxqGH92LX>6v!G*|cBp$% zROKeTdTl$ha^O4Oo@s1H#M$@3M@K0cM}m_t{zTjNkQR^O%?=lwAFV4gq{=RSQ$EV_ zUR7D=QANsf4gBC$Z^+6G9bKL~CmUY41>HLmS%5#2cH;Sc&-MER<7YT?%rDk{xUfWf z8a#e`&Auh#jk1#r?}ROHjI=XBn@+%6m$O}nbG)w-(3pEKt*0LZt=*Hd?(Ltq#*VE; ze;lk?9uCUtT2pR%BDzxP`lSeUKAFDqZZ@>0OjojT+f@p$VJ;QQrO^)tct%k@Ai)}#d zbDNJ}>2qT3>|VRERAXN{_H+XeD`Da(0gUc-ycPAAZUw)6k&_yWX(HeGPZsP7L3qyT3>iXR2R!P6jf+Z{$EUl5tPfZhd z!)EG}sbPb4gu4L)4jE1ID=u$L`qZ%^D{u%m94I-0(_|&QuJ?Gp+7uoaB>ArG&$|&f zQ(m|P4$N-96Yv4Er#a~4fUGs>n$s=Acg@`NEn$(~fl!MeQ05*}C|c+4Z8KARziMN* zk|m2cJe7qZl|)wCTO9tIPCrWUsp_(psDY{6J0X3YGtHr2RZZLpzdmM3ix@WUz7sSv z>(&Z=g`uP2=Q{l}g3lP=xI`bP`Z3}z*e24#hPuk`at=?WWkg$zWvxLLD&bDtwjmLbcMoF?YAhA-Ij7}AkWH%7!4lX&KG(ttZ>V^m*BURzMZ zkY6k3Jm}XJs>x@x1hr4PrH6bxked-T%kUeE(=-t1vA<2_F+#p;YaI);Ez0f$^qSnF zg0E+*(;`lG_)$ZTS3RWosuK=sfSvPV6((Q2DKcGg|KD+^&4x?18#JSZ)seqDpor=xe7rg&7aNZVJ~5 zky{=KW1xl+H3>F?|M)w5m|waCHTr)jnHfozRbZoC7EF{i_2~9}18TPl6M4-X_8!+k z5OqjhSqYgoo zjt0~(Z3?Sxf)tV6%&IJuv2f!}-*zGNcO@G35U6I=VI7DsQLqk&&4Df}R|~XZFlH`d z_0@BQBx&(DeOte*$TBw)nSQCcbfAoX+Qj`T_H|1^m#jVfr7mu0y!9RG1oviNj;~zQ zx;7Kf9{_tJs%nA?3u%J>oUs&73nv%GTRiAQwFhEb!s=?jj zMqhY9!nk<4#QmstH&0n3TMp7=M$3fXNysirNt3+4hi8UI!duvF+ni-Q ztXuhd#dz!c1{ru0G=$f_-lXU%q62D72snuqsFT=jWp*|HnWQml;1azWB=BcI%4gq^=oc+`ch^;+?UJhM>v}E84z`azW$SI@T zUN90@r~ua@r!LA4lVwu8t5r^p>ryb*_x&u)wW}|mzeW)m0X9U{xL_YH8^F8f;xs;1 zBYxW)OwWa7uBVLn2pDk*@4)lJo3OeY!qz|4@U_#F?Q=4B>y;bZV^_A8iV4Ep(@mLE zif?P2!)||pO;IW&>EjM9wLZIX!Y`fZ@5B-Dg9J(V0Z_9@3@(_a=DAOuPWCCej1&hH;3=JiKV7fv76vwGjG7Di&QDn z%CLm>fFLil<6z5`TAGK+kP^5&k>C72Hu=5Npj|zx2r5+G_IfSeelb|l{4S^L;9=)F zU$^&qN}uSty^EyCjGUbjl2@$lRhHYcpa*2yaKet z>ommamP{$?-XK{QTzSc!G$`W~en5(kH#cVz{kMKzr4FmSL_7AQOmNb1BqZZCR^l$( zJkm0f{EL*4bw&ufCrn1dfR^YONN|N!*UG*qSzZ zRZo>!b}M?%;}>O1GNJgTF5-^j-^D5Y?wJxwob`38LRHTH7F-XFiCTNgPw>APphQrA zZIx4>3T`QhQt5JHTYD96wSp%A22iUTE;iu@r9e7`IG-DG+SnQt~ z3Ey*xmNY2?+n)iuTC+7{4{?{*4XE~)+esdEPRc5#hHrD~LO&HT3~OjJ`MiE4P|AwBZpO!{1YrWoeN0iad#F*q^8=;7UyNYw z!Yzwr3SMNdnXyA^721@PbRmE3^*Q~fB}d5`Cvd%|?+@=rn2MSx$5gaq4q7@${lrMz zC?Teq@Nrl|4ydx(~N4sq`LIv_< zkpy!iiewc;*h<7`w?on7x@>Z#Q@779_Z=XVWs^TtiF4dBJ_}!B^pq>f*tWafPSZVi z8Gv@G7F=|djKS=xuk7?0W7VKB{doWbBl(6Lm}^G64*SjF=naRuxO-%q5*9OE&t<|g$7Bvc5RMbFGQBi{{u1x}BRJ6fT zi@wwbL`z$1qeWk8E#EiqAD+VzI0p_hbDp`M>-k-@Kr{WaoPG(V8{})h?y9&VX56;X zFHO;#b+jai-XfvqEN*DFt}pSQe<_Z6m&|C9ZTJgYuwS3`9H7eq##4Q?8cHyC&u zaRKr=7H|t=@V}wsV(3>G+%qE~ONZvhAmD#8H800}7VGFuI{JU#(t2zTEF%=I&;B`; z@Uu$}0l{Cd7f{k@i~GXPAeAQ!v=?If6>P~F1YSFbeg>tVku7f*IN?WG-hg z*7F&MBsA&L`flFdR|xA{!`5#lbi0UoL`*Z=#U>ju=X-XajMm3vOc<%hfQ1SkL-9kz z%!j;;qddAoyyOs%c0>oZI+It;f+EC76rKk@0M5n{6&{p|5n)q8u?geqfho7vlDxR`lUc6a` zGneEYcE2wHSP4+RgV#C2m2zmim6(CmNo=Hb1}IrZPB7p|FH~+N zORP{q4duqwgh~KfXM{>6q(qeHhpA#waHSD+)6YPg?O>dQn5-ksvcewf@+h%aNCOz8 zE{%`@^9R7x0^&-PXf=_$zb8p_q~~nBhMGY0hn8c+W&IG}4&k}Z*f`=U9(kLAl5V90 z@oGg9GJbnrGL{)6Pj^HfUeTh zcWUV>9<>F7TM+trfPS6_ccb;VdV2Al?vS2-!dAcUe%-}a4VQ+4ejZ=H{PFtzcy*EO zRIlA(wSUlaF=I(g(13(>PtLID7?;0b-P~-+!y)1^lvS5CibR5N`Q0eD6@MVa(Gy8s+4c$WUgNabFJ_VUsZ~ zi5WqKj4vp(4@IKvx?N+_2Fu@u=ryG$AQ>LhkSuWfB6I?-4pDh^}^J>;kxl zY*gGI{OU())?MmFl-YE^^|G${h_0=h2lwdcZC?|&IzuANd(}6E*w{fKZQD#T3Sk&A z2HBQxC6!{tVuVs|IFTVcjvu~F(~+e;Zrg2?vwO)+FDP0AWrvYmEg@G*$!kpHfs15Z zalZ8f>Hbxxy+-)UA!?0vch6k0S#AS9;q0uYOPVg_9!^fv?b_kOZ`35MJw<@DRUC<)* zlo0TJ3|BoUVvF4RK_Shh7$6rLNV`Qd=3h7vYA3`PZZEe&Dag0D_S*0OT5eq>#H~d! zSi~cy$jEguVEIRai;l!yC_KXbhEFiGIJ(R2cQVmC{NJ(md&^@@F=s;;1pR zGKTu4nSKJf(=5Mq-nj87nckR%KWu4>lyoVBZqCm-Q*rv%tFYr?rw4e!FL1dsmU$?d4AI)c8`;TC!4L`%9`So*B9uM3h zPTy&8ILm`~_d;iyvVHA9`@3uhIq;>PIwA{D0U2kJ%h?8MUcKuX3Ek#4_=|zsj@@<$hql${c$S z?4uvEu@j%-S2AUxf&TtFRap;)OQ4uZ@3|7vtlC3)6_7+i6n=!({sqRwK_zyjpAf>S z=`1U$k_WGdqu@|*m?K<2MBb=#T5p3*FUfKv4no3>GD-<{d$$-~jXKY`?JeG^C$E<~ z)yUyp&G1eu+{E+Ph0>n9fIDP<-90W!l4!7w}I zX$13SB>xC=fN^15AYR3$i9Z1)VaFNIpiDif$OLB@LYa5oGbj;-%ECYnP9N9^ zBSqk;`a6fMT}{I?>wc)EcN^)E9rRvcVeJ6zu-*@M*Hsvtjq9VM zfvJmW*Ll?Q_JP}S`Z=3J&f@&z<@8D0&>idD0Wrf=!8jfmymNG-BX+}mx26~NrWazy z#Y3~2Wz^^6fd(C|TcTA?;9tLSkvaIbk+czT+X|3l#l91jr00%cyOp*ZgY`B_uO1iL z$$sL`@FozJ;WGz2K(J9x*eHuI*hD;OkyBfBj2s>9+Gb|ckag*QC|6|;wbp>&qkmifLLTD%#%_2{&mCY^?Jlb zA|`J!kn8XZ^C$9MSGcSO-e@Cl`--AQ$QqQYv5|NE<-8T!RIVeJ0OZxOLo!MHe17LT zH1VN^{DrK0r;Yqw7*&N*e+2N&mQ=YF&Vk^~@~`%y`0f>KeYOTH2vzcEGuu#DZJ>N* zAm)f^EB~Ub!IMw6>+jfpZg0p**fcI^Z*(MAO2{80i71K{@gQG=ZE_=t`qn!iF!S{U zZ)~&(A(!e10xO8_En5SS+18&54PYz)b*?0!cF4^Nvc-gGBNWmAb$*L8QsjKMpMQHo zC>E200H_!x_gy2VBYq;3oT=|kwn8f;#3aeT#ene1F;cOP>nlFtZ@8a`2@~~Drj=}l zJmfs;qX(JIXlBy6ic7E3ieHYm@L(?W^q!EAjal$UJM9)3F>tf1@Uh~V`Q1$(m$rzP zom=s(gy8kdA77te>8f~psAABI*N7@io|G9z!Q3+AV&$jcobkEMe@TWu*1mH$%>QY? z;oT4P{5h(qEe%jVEdC~ab%$D8dgs)BZt=69z#q7a&y<{1|H0{j)J+UW-;xT{&E%KU z7M%OcrE!G6@ZeG2P|N8(=St1%NWWhUFNnxFX>cfd7j_^*`YQS5p`{N>&RkNJZycZ1 z*vyl*CI1i*@~BNvbaLC=2cHh$(#ezrFHC^RuS3H zEGciC4PJY(C-R(5?Kp3=_XV9dinl|gBbCIxsrtX6y%}%x0Zs95GRt;Fypp|O#f%gL zFk+q*r0kmeSQ5aRW3l|jNEodlAsJwMfK&R6jOa3J3hsB92RgoNPAA{fIL_y;?cBf2 zyK;g#%P#2Hm7d^HPETyb%=K~4Dmv=@@{&}nd7rRxm=FA2R+nTEKxBTuM8U+_S&)p( zog|PstMy~5sQdzq#HyaLVEA~ifS_VUu-iwN7i>bUptY;7DdKaNcW|UyCnU${{%9w< zt^mgBJZdM~m44d_2zBA-;J(nv3v9ya;}+Ycx`IUbH}AX>=AgWvf&GoTcL!1qIIR^ zZ;ew)6aSUBWJi64cgx~;!AENrn*Y;!`_5aEtaQNR?$>4a_x|Hu7VL47Q;Dms?)j4X zPkUz1SU<7CXA&xYYd9g|PI!yGhhDe|OV-|T&6o>>tDK^=fbxPjY-D7&RfrJbA)ZDT zl9@|t4&f=xEx+$BG*j~gxVI|JswcHAZ+J@bxz$kUuU(!~c*i@(62Gdsv^@U&fzn5g z{cTHj|Goae@+$g?=FOI00}qttd+a&-*c^J|z={batF@gHm!I1tnOg033rSb0IgLxK z0qoPgKl9=^jFn=Mx4}qGQ06vd>qO2cr!~9^4zmjHg}L_+F+OgH&R&r@Z^SYfRpEGR zg-+$Ny>}p9DW&J%EOp&CT^l{zM=6|Y_j=Vk5Lu#mmAYbF>OZh^oLtZBVhd5OxU{r+8ekYG)htaGbu+HBYWie2p=adi(r` zWt2JgcE@&kKL;JFjqSygcgB=gz-(<|4d_|mMDl%Oh38G2UbPi{={bqOaiavs9eT)P z(q0#H>Ty<`K)oN}*DV6f4pq-y?qj~EInw95nygJ~m9mN^3I2g&wOg~hDV2K2vBnOSe<7uZTL6~CregOmAk8u; z-FWy~*9wp$1xjWuTIZoQu@+>CoLcM}0nDK%8*!RLH>2^G!r&wW+hLuJ;M)=`gi8TqiI|KW{YE>B{#jadu*&$e_CgWw@`+^2R2Ia|e? z0JO=##;=}Fro7A=b0|o4Kstf{WCBt~khS9JeWyM8WN5Be z!nmSwZ$+GMvoZ50t%?5ir&4znPqRR-*gb!t((T*n+EkfhcWK@qtUVH9=nn>sLU@N( ztP=$o2oB5S<6&QoJt7(Xp_%gnImpsSfq`QU%SKCCJ25l5SkVw{(^Cx1XN^a*4jej1@CYd>?ls3pahRyGvr zRL)0u4@1NOO=WGRZYFl^_MxvE3vkEtg2jQKz|k&kOYvs;5o*+#ZHO2RSa z={|G5jNC<=6b=8nmEmAPZN!S9I$oRA?WUfj9)C+Kly|r(3P`K5x75jY$bVeEA?kBQ zmk%pgHztW>HA5^ zGP~b%!OH7>2i_pz7!h=lSExLn@ahuqcCO96W}gvTUEU$n$(?AB&tt1L{MRan)ok@E zr?++d%=r#0*~9?$4_&3t$DEsF{o(5NIz8+{a)(&u*TQkg$Ao0AT|w&wNDCz* zW|&ds-!=)0V3AYXg(075?96hxqM_diH9nM(W2}TMsTGQm&nV#Kc6U!J9(_+(6)Lfq zlY5SPCvhtc~++O zYx-h&$*9svxFkG^48n(ZEp6K=0GZDlhyiFjbxnY0)a|mym4hEPd}||xSMcE#y=Is5 zdJV2LH!Nv24}8TVMP~A;-}U)lQDV1*_;&Tm)_$+L7inR^V3WfIv%CI%ZS;m7(@Ocn zkU5+0M1Nv77ToM`-y}(Czb~mhFn}Aas3~%ixqI2Cc6Sp{Cn$Lf_*wMC{zE@Oo{}Bu zHqas4VnT_EU7K|sE@y43fCDHK-))&;9g?IulW`dYrN|^2ijz9WqABXtyt9z14Vr8L zsZd86HfV%wwUk?%XCzgkwLas--`VGU`z0p#*w#3MIH>A!UAuh8v zE`{3*dmV}dlx;jxHo9jqPqSH2C$*~!Eu{Cdq)Kg_5UH!eyjFCPa=Lao|0j7_fo2&- z+19Hm=GLyjG{somVwe}bPLnU7K#sLx&7?|hZE>c9td}gSAUSxbN%#G;U{Y+GCLX{? zwoo^sfewEwwl(_Nl|*ruu=OB@#f64K0>Sk}8#vbj&Feaz zDBc&F3FQ~H6oF~!x{I?tNO-gdS2^*;`W&pG?6tqRLL){kF?3%C4wVm z1b2(VTL4h>1ixNxWN*BW;8GH}&72v7A=YGZ` zJt!q*+3RuRcG#KN=|Rju)ypcBNz*qk{;Eor5r_Xd(#26Lw2?o2Nix)*Rkx~F8W(1v zoHi{f%SJ3P+`=b{KmA4O+Y8rQ&f=!#Ioip202C@t2-kz57ShAD>SBOcXfQA1k`XQe zKeFGg!uxl=g<<(@Ufnz%No;A^?4`~I>Q-<`#WrFHP*;Fzk_Dt}g!oJ5#R4v+NQU7{ z;cTR?u1oV6qErKQ4+3iE&tT{yXEF}fC3DIDV#o4ZNjRm8lc80u$BOiIt`{{$I?~+m zTAZF;fy*^0abi9ytpPt8d9iK8el?PdD>zFmnoj_J+1LI})VR=czC`L)7&TwY9JZm9=!T;us18Hx3gywc z617wdB_pIoy(Ae@D~552R~{e+qhPhKK!w5}hYiiLD}4~v{zgTj-7$#^@nA?OhJ4sr z_-BIdkSA6B6wdp+1!60J*?W~s3PWjaPom^_=ZSXB@KMRS|+Z>a=t&ZH?G>?pb) z_f{43cjr9a#Y~-gupKIExflx&MGnM7Sk1SqbBv@6txNo0M6q3sSxDkmQrlV$IZ}g7 ztDU{ni+E739g4R>SNAOx%9IN#pg;kH6A`{B$Twz0jSwQ`kn6P4pR4rb;x}O|N3QU* z6CCUWKLfb#T8Nt*S7N_u=#;zb37!T;q@Ifsl&%60>xHWBvsFYTPT>Ur$`$cSd}Cuq!CDDgM`L8$s^F?0OG<5rQVVF-DZ4@s7|i&y#4Qy z%YtX$b!tlFq!nz+uqUOIOZn13*Simm5i&F>qpvaPm4p)5^(mIUrFlvqQ_+xf{BoUr7$-;ToPyxm0E-v=WzS zHBJ*G$cJ1Ble2BArE={`gJ!XWAr&i~6KhAgluD!K-(gmjl^E$oF4U5`k-E`d%5t_J zw8bjbQ+}8cXcmiE?|*8{(`%$W%CMDK43jA)^7I*!*r0(*Od%NgORV-hwKmVHwp^l= zB1D(ph{?7dN?6^cOvbF(ItgA_I+t815oe&d;jbAe(EQFJV@O@`G&yQeBejs`s>ARfACl9Y?!f-EuD0L?{6)G2kYU5)RD#TS6_@`W*W6+RD{rDey^RB^JCNvP73 zrxMtdG5CQJ0fgcmgWCM$UWhFQKO;fmS}1iAoXu0kW-0}CVxl!YVHyVw)U$ijVzsk< zGgSh?0KQ9;VNpg3RI^8v^9=ifsOnhs>S+gJ<75DH9J{(n03?N5v$z>Y!HW2VF@o>Gur-0z6S48n39$7sbfTG!u zZ>!_#4+Ktyii3cwZm0xWRg3}($Dl)(oWl`9tR7TK6I^u!LpT9%U*NEnj=g{bo4~iZ z?0E{#LW53zWS`jVYO{X6w`WsD^aHv6f*04ATJH4p8+z&t8Jdu-~$E8eEfJI;l>EYu9@o z%aLdZQ$*>s+P9emIT7Kbq-0}l*$8n1udYC{==ZalfB5|OI!%=zayh%Uf5;T#S-W^d zgVUY@mpUY@Hb+3Jw737VdzaMe2fZiyr)U;W(-+G~#R9U7d+a%<&g+B*NvX}7)NHmi zzChGPGgfseZ-~5>w^Z)S`42HfOUmHJAnW^p)Viu(%?gy9WFciwYFz(An%k_w&(L*S zG+zQ53b=5YpeJ9ac@m}(_LARx`XR?sCaX|Oap@dUv)OUMI;isvG>cmnLq3ye2cxY} zK-UT1X((BT>jX-HQ6&&7gZ0oNnkrpPNGez+pBUjI%>r1xnW0Xf+@AoecHO`YQdNp1 z&L1v+b@&P!gW_7lu|iy^4qpwC;=0spT8P35Rm@0mf)-j$CoX!N@$eM#F2ZMX1iq9a z=C+b(&7|BB%{)KiSXWbVE2(5e6^t%P7K2<&x#5sHq?I_&=s1&L1^p)p_~uAed>9Tk zC)kxyD7gBG!jB8?YEU>Hx0j{QeD}v+VBw)YP*+B6Z6?K#3jFL9wm}Qa%Bi6^nQ38{EQ zHGf9cWZe#Ej%)hq zk749wiT=^=%S+khJfpf8t*f*Vi))mB{H}f-L6S`pt1t~z%3Q|PJVwZiaT&UT^wy#F zCzP_GRa1s5`4y%j`1_)Qe{W(ITIp9*b_%7=hr)tJ{ZAs?>_n6A2KI_gK0@a@qVcGcN$M9$k9<#eA_o0sM3NZ)bSVVXO%pdh0y7 z=Lps@*-abHUVOFqSn?Ln*=N4z$u1k(oyos>AIra${6*X5)cS(~*}E?Z7mqi&`K9cT zKYe?rQ5RWiDgU)6rBUaxaz*Q)UNeaJBxk*r4bzu4*H?c)O{Rc*3?xw}@3MYFcfdm= z+t;Q6;x*t-8C$?+d64ZC+?33!gNN!f65~& zDM~x^_!aB_=B=}QGG#7pYLyQ7>~t9)@RaC;dXJsmxdWd63Nvdx)^+ue{nxj0>I3%W zjavfuSHGY|{Uio68=8S}GP+Y|(s+4j2}5qaxHc&)$pkaYKWpnz2dIZ%-SthM8>|gr z{ehRcs04wqpAs==pemy1c*^#vo(#Y3)v53Sxd5aGwJ_dg9XyAT8qqt`;|-xVr$Bn- z%?&)OD znsj^+qEIgx^zIM`Swi|;j2e$Bz$w#f&E=s1-(_ycXHvbch%eNIp1%pwg3iPKnIY;n z6U8Tfm@|NGj;+dY->M$cc$8(1*Sbq5_;*9rk18ku($UHcf5~9(6ThuB({*0;qdk3o zi%;E|AJH*7o*8hpM73+l`m-FGzoKRDw9jg*qeZQl;AaNZ0OOQ{d)3N^kt)G>X2^E? z5EVTKe@}}Wk*gM*Y1{rdRkSgd|I%r@AoQ7gRTE#$UfE*QxRq&5ecMaqW02$4^4xUy z!jZdu?4pRb^?rhgvAzs-Q%@bs$tZ`D?FzSko5uGhK#o9^Zla5%XuAXt8kvzdcwn@p zow0(O;&uT~HEky_62?R|R=m?@!;FHQ$7^xBB34G~sY1;Yo4YdTVcf`wEO z(jSls^m|n650~VPGd9{yo|VG-I2q0dPOEU0z6|{D^?~Up6d}Y@s%B?>+m#y4%5~DlGVe0v7c_$}uB0tl80mb- zk`|eVkN>t0QpIvnPz7G>G$?X58xBTn`#@WB{Iv5CK{^lcq!)IJ*cV`OSih7}FyZfw z*X;<%J=6`}DwpH-)#0_3%u2~O?&D}(II)LXX!*NI%>a^@#szjbG@D&7Y@x-`uhq12?RKQ|E$Jtff{@-8g ziqe|`k@_Wcbj5^rlU%g$`Volp4%0518h10}yj2N-c6Fl}p9GqnC*%u)+3gM)R;Ak~ z6!ZUthoWr+)~nx1{3ai3!Zd^n*WY^^^>Jx_qCf z=1*c|-Xy`PS4O_Iiooz1wjsk-O^jMdUS6ScAGSkBi~@$b6lx!IKIB*f5QAFV*}cHL z7{eG#607nWM9OEn+a1=)QgrPWavbuO)jiYX@A254!2HQ15=U4)TdFwzho$S01fqHy zmpW53nX(aWF8eR?%^cmBgB(%KziETfOq;`|5y-1RO`L;AV2cdT{QB*bz?whYlx%hA zN8mzw!Up6kohtbDFb!iXkemXM^Pw9=QG|)#X4M?@zCz7k2;#2oo#=H^+LC_I?Y6ch zs(LIo!mq-ixOJ6Ff`pXOP@k!;0bM+U;E+@a$$1(hYfqgjdn5N&aWx@v{77+&((QtP z6dEU_Rdse0^#GLks<*r9pZEFAOx{hZHc_OAigl|iE&Pq5p;+GM+N)f~-`228(mn3m z&r6Fu*K=;`P@lIwVdgY1;u$qtkmuHCP0gRjr&qPT@V*PAMP&&a@-~z@0sn0cui-GN zaX#_wa*JZ_D z^qjQSmgr5PY3UGJM4+XscQd!LQ}aB>;Pg*pZdWbVKVFL|C&ddK_F4}G{${V8zfk0~ z*H1O$mmsXW@vt4)S8*ru(J(9?9HSFxB(!vhn&02Sx;VY~?E!!^eos-8K1ui+YA3%P z!(iKAifM}&^sX3#6DJAo&#bEG$svjWF#9g{r)A*WTB{|zM(y}o?s>_?w%=v9nY;mu zH0kuf?S46gpQC+&hKHL3|Nv+UWuZEVKk< zDt{-j0*t=D#e;0q3RL~5j_9YgtO(K_bk7G}rhr2mbIW^Yg#QZDg;RhRHh{uU0}(K2 zAvo)n*)OI!fT8r8QUtsfmJ*cyOwf}GhPNuwPh)3>l%YEC2a^c*??QE=pc7!&AGo{t z%tL+0nQuOKB!kjv3j3r;pPSwco1Nw=LIn!2MNTd;QY80@v2sGH5!UuOwJ9QjMSeHI zXaOPGs&uYcnsCG%Y-|*b^(76b%o|QglBXmsToTdN#~U<?ntLv(=Lfto9|eZJb_}QukQ3l(j720cKLIQke3Rm7xq~uoQWz#v{*d;Gj-X3_B(4l5)~Y@a`>)vY1mxh2fbVk_4sS#Jvz$8Qg02tucq& z6*JB#p=(Ombx(){jw3P?(nerS0$#>}-?fAduM1cf>_?Q~$RIFd6*y_(s0CLe(awJ3Zo3a63`HiMwvAjHq! zZF#=btpLvx%3)*r#2pK>+_dJkars&kvqspY7kd7WWNn`$eixald`()$9Tf!_S9k)2 z`+J;Tx;nDuggO_?PJ70>LZ*jHp+fCwCxUEErae{^G%5-bJn`S5ZJZwSoy=)u%>SyF z5adGoC;NbxHaAyh`ZLlWoh@WL5nT|xDMsiQ(ShCsXF1N8!?@1_BkF^B&fz8B;}?%J zFF-aR`X-5BKjIK+{kcL4|D!fWYy<+Z`2kL#)3hSER}?9`=U8DpZow|%gSl9Y7;A+4 z`O}`W7s8gwxVx# zlhf_#1q+@niS|wjoe+gc%sI2Y$M=f(qk)meCA^YSTtzDD;PrTS`GmH;?DdRpVG5hD>4}>MAPm-$C1YXi*_6DA+6&};O6W}Q>7Q31H)inAZ+ z7b7tT>XiJEl&&jEU&dK{4-fAu`w`f)_mc@_`}t(v>qsJ^Eg;8R;mjkPOzQWr$=Noe z`d=ZdRp<$rkY1CUw?d=c^KW)O3)%aB-59gRv>tAz^vm6>T4c(|_7;A83m`fiKF!>_ zN<8y-^)A1WmE>!WakI}66~1Ph*m4De0|IEU=LUkO+(haWx`;pfqJZEbS9s`5ZskCg zI-2k|`(vkR9kTboZn>Nq&TAi%|0iK!syVF{XLA{R0Q-21H?Zc zB0chJEA-NN_+2)6Pn!@}*^Hab?0Ub5XX|s#*lQtfNc%BG&aFey`iJ2NXg8SyTfx9X z%E&1I7aIyPyUtf)=mwE?(uCHiL*e4+Y>ZrvMeP9i9!sDNSTbjZk3Q9exNU#eV=#H01Dbv!4mV}7cZi8klx4vN037JWzk1s58_2SpKDMhG6gY7yl;;)hg=oTo&0E5SHJ=lmhi zz*`WRIZMDDEqD*!-vz~YE5k>__>;;x5sRV+%@MtF{Kl?kC~#iWr@+Lp%Nc_J-3aVB zqVVqB%N~@o5z|liOdhXI9y)*f!kn_1`>+Q9{%myGw?*MeB>79U%xjt9vR_#b74EU- zMS+jJYcvIV&9MI?pAw<3)#Qyk#Z$XHafZ~=zjLm_=fBlqv4hTY70#~-uIRX5zcP@h z^aw0v*M9G375cnZ_>3qmgxN1c6a5XKv4a@*3|YlK-+#;catt=PCVw zXO2yN4z8_k!`7cCF`w=cABo<#&4l~CIFyWdoBSAYXsRfv`;IE}puWZAA+b6l3inn8 zVl)1=TS2K-_#4^m3Dj$D#rqR__6CBx-sYD1+$Z_zn(btj*rQfwT4$@6+IHSuk9?F9 zTrLor|M+vcr^n26%t2Sa*5i|IF=Y|X_%@ubr*4KF6(+aTKiw(_o|lld?t5o^MN}=O z3x%6=72ZA~+wjkye<|Cn0RMjYbWx^RN#lnXWu;WW|-$E6CG^>qub1+l7^@T z;6y5LG7F2AnZxXah*4z(B8tds8hk7Ka19Kr5b@wV=Lz#%XQg0+CwEG8+qNLeUUpj& z^W!~HB2yJ12z0h9LiEs&NBZK)4qV3a@3h@>zrpD!@EQ#H$y4S{tKx@Ju0Kr)SvZpy z>Ee)n>s_d>?Ayzy!iej?lRmIW^7=NKBZkePKxd9pl(SZZ4quC6pSjy6$^}pIx zMS=Y|F7DfJB2k4u-61fBO)R#_W5h%p>0VoH@)pE;445E!< z;+N(3?x&q|6;zpsR48Q7`__`n7N6gjo>>yQNWYs(4`%ES&OTZCFkws9QncYN_D5Jy z7wg*cvucMIp?jAcTYf1)LmS9&4q0Bf)#0Dv>L1SJoQga6Tgw9W*=CO=N4)}fx1CvW zVcW0%UkB*E!4%03o_!fqBz`8}@D z+QHaW#?HULNq^{*y6Fw^`ir{{(fPXsq$;RFcQ{Gz{nI9&^vj)z%giPBHisg@0DCH^`kzB5{&dGbqM4$Df=G@+-(Sj64%r-p6nIz*9>f;70UT9G~ zx49?oKCi5a@#UnD&i{b)G3NCH6HM%rz}y%=*xT2bIAQpc!Iyu+DVqy;!ZOy_`%Ns* zim4iWu2@8kkzF%d<-N#rmrCium#s=hv}r^{k1_%Q4`QzXeT}o8N+5@Xp$T(V+!JwG zPTZ*M8)SaqP5bWIQJ-FzBQDg;qO_kttKlY?t%d4Qr}OM4mw3`u&R7MN5-@)jNMLg*any= zL3^?Brhp3qWfA(RbptCN{PhJRV^B813KxfG$44cw+f|WSPg9(79G@xu1BR$+PkahL25>=SNeCw5@|>|3zvr)`*19ceI<)G*qJ+Jq zw!&@q?l^9EWFnm_Ua@J_*DLz(nW!Pt_5hvWlSPs2zXR^r-Pa)YCv$2>%9?oUX)~SE zFAz2HJ^*FA60dVr^zcdeo86qj-tqh==GL-Zn0ZaV!y9h`v1OhaQtwT4+R(puW;U|6 z@1N9R;L(hrR~vxLyDrO_+bBOcJ9TfPm|=EQRX{=9Tq-uDaM~(M4GAQ}SvQpmHGU=- z=phOtOmGsr57{OmqOI+$?g|nLwAVR}Drj@@Dzhs#$ge^>emn`NYz88e_wV(u-M1$*XT*RB&Cq zF3D@d_Yw*!t~(uiVmy^|^mH>GZu5VF1;-TJb^PPS@XTQ=X@1fvVZ5Pyf6%SYGZB;X z>20(h4h|bNvD&cw4$u8w$PgZD{Vcy+ppJk3{b4$P+-+Mf{!e@>;di{kM%w>DzX>LJP!eFr4~As8`)r%VF~ zJ8D4ZRy}pOZjD>Qy*s>Vq$wG%LAv187O@FGGi>N!`K~RBCRR4&k1Cz7$%;4D4m~b> z4Mr;UMf}wqSU587GQmxYaw(wmoX@8pZYy>EUxYK(M9;>h*^~DEkXlY- ze!r4=bQ1F0BxHT3coBGvUBds!V-zDDt{1hxhgV2r@+^I_q`XEb=>{d8l$La(tI8+h(w<~o+pWAL=rn}?-2MQk% z*QQphIiIeiDC|RSx7oFi)f~EyPDJ!1(3}(gNy~Z-ao`?4Rg4kuf0SQt?~ zXyB+^Nlg$sZv#jnt%kZ4jBi+9S~b4;J&lsy&yi~*B(8C+u1Es99OByIwYhcDYLUx= z_i8Sd+fd4QvF@hfaPYbzO75s?8?em@JCB;v2b~n51 zXYux?Mds|KT|GVSaz<;3usFUiZ4t41fN(xJW_NOTsqBg_)bBvn`=yz#Aq9?ie*E0A zp>#usvl6(=i#5?IU0yh!1L__(@TrzVX4mt^^vF$|B{SL8rEa9rn&@rq*H_42uw7by zi>e+{<>gRk{?vk5;NbP8J`_RE7vlbxBrf^fsGn@@tG3@QFZaIw;)yUr8NbHwNeT zsyrtR8q6wmT8jsO&*|~y^gYJ)`6*sCeLliZ!ss8q*fkCAt&Qp;u(rZ|o}YN_Nobv; z3P=sG+w0%WaGai-98qxS-6Da={cT9UpNvmk)^yUPw0(pt#kEcc{J7_KBRqjo(rt`w z@g{jO_y_|E|`?PiNn?sRy72> zs@q+(V=fZ|QNap!V2tZRAt#^X-7jRzg8)}$@6#3YIkYF zs-jLBGNmDzir*FsT`O|j`p0}NLT^dwrlO9X{0H7ff(uvclA}1q&2%o8x|O#(7b#9g zbU3GW_`ETBN!nSs-?>mBOUAuJ)5cWkpW+<1&!ui4$9*;46bzov8UqeUn4-)~@&e`u zEVxb1BuiP-5~euXuTIZAfKiqBG|otuqs*)g4BI=#(Lh{=#f3HY$nbW2YFR`7QBJ#n5z%}55_*uL>^yN6U(gn~1T#VJ|C@Z2+TMa@ghnNxIXpCeHS4;8`Y_$wD$=4Gj@vMP9sYClYwgc zZbrPBo;w*aVxptIV2|lY>zBuPmB)x?`g@t(s5N}+DcX#Z_J+LxkEO#cHwS*D+ndAR zSYe)OCu+#f4VhOJHzPONh3~K{yqJn3WJEq?!G(WAY8YmX1h!u{?f{Qd)VXl0{sW!mx=%QH6~bF8D<{(kWDx74c_EW;Z~ z-zzC+_D6MFou66DCrsc@B{{1Qm<`RiW|CgU*v<8t$Ul>C|1*dM;k|d7`)XFI*^t77 zuNPLvD)D@SyN3+ht)wfh+y9kC|BbNHR1R>^)N7Pp&L9n#9TivGLgnyb1HHI~HH~q4 zmsVIZdHLC8rcAzLc8HyZbr5h+7tnQ{_`ou(t3>5^{;^}%>^cqBX%oj~riE!!d6r^b z*S&EJ&2Y?bWqvd=g_D1FXTdvKc&TRY$fTp1<1}kvWtyGxf|x0)%_WmtcTGBeC+3yp zIGNrAw9eazBHKf=j#82HacI<95TrAbLI{`eNBfN?6|Hxqpbp$p<@&Wc4z9kkb0hZr+GW)fwE z9o#0J$7JxPN#EZP`Uv6(Pe!~n?0F-j7xvptA@rGF(}^utv+%Sh;iiwRCTr?{=5lGz zkUgGq8c)9DZ-w1_kv&tsK>CmKE{^cneupfTeGUhHw=MJhb}DuMvJ5R{4q?S3lxIp% zZ-Sl~f<87Hi11?&w?t|E|L9rSmyvN2{Y;XMB(Z!UyddMv|Caca>|zk#Gep*t><{IBKC9Vcei4qp7sDBCVTZSNz+7Q2dh)#}&rz^hF0B zj~EY@o<6vDOIphEI_Hhc)6Qj@r21`v;aA?%%hT<$PCdq*JM7z5cz5Uhl4ap8oYTK; zwv+r5?zUuc1!A9~w7bu?@5lo^%g9sP?-w$uT?Xn3ABxa~Z&iXjt@LsXX;enC zKyy~;y#WlDWk({FUl?R=2rkA2Dt95>|I6_o3_OLv(^haUhrS5AfLK6}7dafZt}a&5 zS8-re`d*B6e{+c>;jHLXIsaoN#kBCFu&yVp*&Z2=Ba_*6SceP_YNEQ%8LXSvtMW!# z*x$FXN>$7z40Ffig-r98wMRF{Rcza7b__*HcX>M*^0m?2eNT6G9Tok_FBv=J+{Z58dMlGdQ4O(7vae-n#SIXJ7n z@Gv`kl#zHpFc$gQ^(n9aWW9!yN80~IA4;7VHqlXF|1)BzrtHOu=o5x7-r>UCv2FJ# zwPR?VyU5-+=^Q^v|BUPzF8d<5myvhXF1#h{9(yxhMBRl7?T%*G%O?6R#C~Y<+MEHz zm>qgebgOKk``2_p1ve|3KgmpWh6NRE7KjtsR0!XI4zF{{jDewqlKhQnWY#{v? zIp1=eF`!)5VK{prdC|!HCf=l-J=VU%?6$4NejewL=e^9U-JNA52IeBXJTBECs&nh$ z3%g^Ccc#5SNyIFaa$V0@baO>y!Ve>d(?edQ2sfsEcQuXo;K+NPo%lU1{KX!-p0sh~ z>SOEjtI|Wg(?^=0U!%9cg$XIAK)e4l)XnN1#h<3eEK4V!&amE%Xmk%>&$&stC*Bsf z2T`Hhjg&ms}P^QTGgt@ubJG1}C7wC6uu zbV@Z0Dll%*+P~82ghs3G6B{bmo;$(@)foRFIY~%#M58qRDp*oN(_gRJ|>oFZNvx^CRO>K2rXLekt+8iM|AJgLUT6G4Q^Z>+P z{&Oa`RzA8xo1Phw~1(i5(^CKfNie60FlO)_oYuj(_Yn zQ0dUVt|Rhd{^@zn{FT29M_^B%d*mm##;f|CP{uB0AJ<+P87z>VTvGbaj{3o_tObhe z(&@a1w<{0E?E2VX|1m$^5OaEN*5ZmA7^eHAeuUl~#6u?|VyA4Cr_P=~71jJkiTmKc zGAOom(Q&u<%@(iNnEz`l+WxZ85W2T#&C$PZ?{5y@H8%V9xupJBMDFL$*Pr`I4T<}4 z!SkkH(pdBB**<+Qpylk8+F>Pb;QGD!Na=vKY9Q(Ruvq!XniU(S?|X5!O}sZ0m(yFK zl{K*3;n%D4Q&bqjPmkiix96p3dD$7iQjdgdgl*f^1=Slg%%7J^+zQ$|6X!Hr+8N%itTR}mo2(E^beVT%nCKo=SeM|udTgBFp4VnB z*A?y7!f%PQt~mGE41m7;RzY{OE6vuX*AHd`87HK-G}C73EUE!2=pwmUM z77Je`>?{mP4eWoNlj46|A4d}&e_fp7XE1E9@n;{m9~Cf*GO7bEv$%J8eULaUK&u_A z4xrpAPVsM*#nnp2l~W}C1)FL~kip`8H`AbeL<#!#y_n)#+89zZzfVzLEm%8g>98*z zeo)EYWyu)v-IM1n@k(%&53n@C->c>?nMSmpI|3LIj~zR@v|iVNklUH9!y&_gIw(Zv zSMeAtVebP_;w|tKAw(R8=`d&+)$9Fccpr_UTP$RH8 zbF(3-hhslGNesVKw}18W*4Q76iCU;W9wik*E!h0NYsgGRj{{2*CGyW7wPy~ zb?9{8ys&(cPNJd3b+-iV_qW?Lv-HS>g&RIrX(yYq)b`kT9^Re01jn_EJBD(8r-%}e zQ|GqsJ&(Ihl$7_%zVWz%}W{#6XY;zW$}C57$@gwN98pI?0)BAC6Uw8 za^TBijpr|yJ;S%i?KT-6xV_3D1n66+-;nE|M06u(>tw=4%wal$!SNapS_Gn|h9g*7 z36kMyK!7!mG2GlqD93_lzl62ut%8&aiK#Uh(9wz^?biReU>`aq>1HpR-@e*QQdhmA z7w3LMfpgGHtGL1`r>frvBFY(6@=-A}>Cu3F$|Onjx=%69AM(|4YSO4&mlLu+Uv)jD z7@9hE0#gxTd{Ctr*Rbj!B7Y)`THF}ta6va1kXuA6HncilY)ps@GgmEf>~=C^Kw(N< z_3j=0%x#oJf79f>RS@2BSdtJ@j~dc*P^ilw`}VwvPa)Ydjc4EROwb}+dJJbIPR|Rjy~1siH!)L#blGPnuKJv?B?=gw`CrM7|K&%#g}B z&kNU5(XgY{>7}F~hm0oOv{TNS_G5lC((+#NJ6hh4&sl~K*Jt(eoig~K``j4w>78Ke zN^%{IuBu4nOQ<#0pFCb#K_T^fyOV|iua6ee2tiC!%;Nd#U(31Fk0}=pqOlNLMIc6O zx7I2~SG?5=^33FYjr|M&+k~nBq}F#DR;`p6iQ(+slov-|lt4Q7>1(j59o=|JXt>wU9yZ3Li`Pu8U;*~CM z(e25p{qu9B>L+?#`~*itQggiBI$Git?QkW$^vI2|_P}yUf8lsNEq!3X`%W@tZ9BtZ z{k8JHdv);(a_}y-y!g=5DoW&ZUG=i%ZJuAAOZhUT-=&3@!=b!gny`p*{K;69WUK;t z>#ub+G@l7?NAK6K>Uqvf4TA!8TD@$*`~MVeLDQ0oJY63#QOcXx+5))hmBe^d=wr1k z+nr@5Mp{&Kh2eWg1p;_llwm!s;LArE`&H~~KA{^)iHa15CZ(2z_K0l9;Oq~>_-G?C z2Sy`ibpA}*(U1dQH70m9m)Y&sw{CU5Fc=isLM#SyoYSNdufhkUH3O}VQ-<3C#VVS7 zY@u_Xm^@C)ypM*#&f4G6!%f20?<ZZiG*cr|=`?VAQZZb+2HTKQ#yvwyW&t*1i%m&yTkUwU;3k zvv<*4`ssxMlk_r$+`BioJZeBrQ?%n*&n&u#;YnKY*KzFo#`2>JP*1Ax{d`+CaSx zlD{PWy%I_jmn@Dax?h#(v{5z2$n<)yjP0EV6S*d=7feREJNs06h)Z9Apb>N=6EZKjH&6vB3=j zaD&orxAlLnW2ylx)>E3K877kZ*p#VqyFPg94{U0mf*z|Ww7*QIcaqLnaMKDp z^94Yg<50G_@Ok<}fi6a85uG z7f}rCqI)6GCFvK(tdPb?Qdz?)HZxT+a3cfi>L(q;K-ct?6al#3DlS$KWERf?c2t3d zs4@||p1XNY5q84JHVkQ{lxTQPED?Z*&JsH`;6{DYH&SpD8)_~lo{)m(uq;&A+Gvd2 zs|1Us;NoQP*a{GV@o`f8S^+Usir*$B)))v|+2}vPq5TJQJG;Qm2pAWPQiZ?{J$Ac+ zkYpk1pB&U$AuZbchmbT@=#T)~DkaL;#6^0-dMVNVH^K%s@tTad7bE(?;`PMRxZjpb z^iBfmcRCh>2aIg@oqHX+#CDtZ&VQeH01-Q=W%IRS`;fiAej%q5o9OjK_u_~xR+-yF zPG*MBycu4em?$?TNV^CzAH&k)8m zdZlrcgyS`7?vj}kw985EfSnGaN~2kxtJA{o%Xkkw=WOPGMA7(&ttNGGl2YcTn}g>>CU`9;gPh%DeOq+BqOhc_`c zZEWvTP%c_1NvB|y2}EtqKOQRBos`udK)wl#S2|{hRus#CjilL zI{#lXSSSF!-bK7Fx~Q%Px4>STEIuVtW(OYhd>UG{j(_9^%|p|}Hz?s_<` z%my)sHvWDVvdh@(q~2H}laU@C1&}iVMmVbU0J_V*d}L;^P}3h^Z1E{ z7n;*^eK)AN-{ktnyyww+)LsvB9Q!lv3SSgnK!EVL_8 zZ#dtFRll6aRuE%g0KpItR18I(vMPKuCSYtkkciph;|5)@wTcZO=f>Cj5lHV0X$+m~ zO<)iLE>YqXv*dakx%moE07HURyP6H)Ekn_dL!?33Hgss&sVF#s2!Dskr532)M56yd z{(Jdu+9`_17qI)^inHuF=(;TP-dN3)iPU8yp+N8z8SSivJZm7W*iJPgbo5=>ZCm;| zCFPcl)L{#A8!W7%h}vWrdD$wM4WZU>O5Z#rL$cY=72$VOYix;p$B zJ%G4CzMG3r2%w#&%o0>6#yAhDNXHfACX~Om9Qb-R^UFtMR5~@!gPLCi)SCWNFqm1v zfaWT78ko|mr|f4(mLZ_YXl$XpNPSfbr`7aB?e~ zcx=keJPG~%3_`KNKi~T8gGpXJ%nCj9=OO5TQ6?NX>LBKWXugahLD;v*SwE zOR-GNxlB4n3qR&DCEj9X=c84Y^>G^?`Vj6}3W5_aFZ1~()R%TK_sDb{spmwV)Njkp zt72_j7SWrLWyyJx6J^gkvLiV&Wuo4Thu~CMbb4hwEf?A6P>@^jz>1)hT zC{d!s$6dfnjE)Px$WOA^y=@0odOS))|6puL#Q-8UxY0y>i49T8-4Q{zB z|1ftuaYg;5h;ZNehk5f_o5fcj5Ovk%5`Nqd*DTgX}J(ohI%KbE}v7=E)#Ry<7$p*$CEMNvv>>b0=c64Zba{C$<{#vk|zZ9ypkdeh|cH z!*NkoygeJcRf_YK_S#!fdK2en#H+0M?ItkL0`69>9V!A4>)IhQA!i}9MCqLWl(1F+ z#>42b@0{unY7F283zW_PVin*H6V$52r^*IYf?tlq;BgaBtEjrNnc0eo_>4M|CMPuv z3LY#z?esr7o+j#t*8FmM5In6UIn|x<=0ME?T9cIW-a?+Mkx)G*=;xKBG{J(o=w)|a zPGGsq_6xRo{y;fs>AR?;9Ai`HepGTcjvZzXIpj&|-vq&Nh2g`pZj*(k0dXB>(P=egSH#59I^JcHj^0;L!P>w}e5I z-?o>lz%*B|)<($HM_+hOEV6+b_KxESczc7#*BA(Otq*}-t1b9lmY5s?ajoEXwDrZ; z3Swn05iN+xHR2Z{;Qe00S`!|<1+YOueB=YJ zlM+2c02L~x>yPMsfrV1yI@w$YhA~rnICwYa#ks$3FEwy-aYuvkeDuk%t>?Y9xV)w@ zcHdYdcKfesu0fk_*lb+{yQaXc!=Jw)U6!_l6~I5qwYmq#<8{u#DGitXr13g+(5kVS zUmqq!)`bT)UCw)usBieEd4#`(dvb|`&y5X^9ZQdYLHQ|ryWf&~#<0hZ8K*XFb-TXh zVb8f_m&I=3GK(>J^>!ckz}J$7-ofxwwP8Pm zJ~NB76%3Ea_)nM(eK>;Wwk1t-@pl37!IN}PL&fn)Ft7OIOmI$ax*@t) z$tr52d$Y`3J+A@NN=^FM6XzTgUU!Ots4bcrpIK$APB?AT(74k|d~IZvqU9js#^J%) zfNPtJqWIithlt3@Qx2j+KJnI@H*C;r-^|fL=hX-UY*?N7DfiGC(FTPX8lD^UQfl0fbCg|4@ z?+_Z~p$`TemZWU)(P!tJdq8rC{?v9*l;TgQh(2r@tBh$hY#>GaFA<3dHi}7pO(o)k zz8M@8-kdW?$_``+IWH_3#jf%C7-CV|p=f}?gCiOYakV)gFdq&|r=iT{k+^TJGNL~B zY@o*4X1(>XP%;-zBVu`Q`p4CIW|8Ti=T#vx3b$rqpXn5B(dXG#FlI{HSN(jKwWZqY zU1J<|!Hw}&kn;w9ULkxV;OfJyJsWB@)x#~2$7_i9MY_R)wtPja5q;__RYi|_u?bocu8iE&De zcwCn3GcqdV*(YKgI8}keJ=}W9iR(j*r_y5{b&`&A15HwvKkvi+!jV(@&>L(so))=! zi_efr_B=50NmlOu*p6|wmL_I$v~;%4fVXqjS?P?$jcrxFigUU26wJ~CFz^2O1|7F& zwqM7p9oM+N6cE`MIX$vbA`^S zoT%|6^p4h;;KzSVmJf)T*LtP|n_5V|JEoYk#@m7vv@YiT6becboy;f)se4}4ELj?~ zv$HC5R1SAp27{`wRhjXf((65C{9Pi*D<9`jYSjpG-*RKRWu)x}4SQBeTC&4i*VJS2 zY%prM0oWCLOxiQyrz=IWR(caV(LZ10Q#9S@cvjEz5~_yL+I#xP7QDx2Yej?(=UAE- z&t>1Kh_Fu56oQ8y(=f!H$ymNqQ1AP%q5?rz*-IK3Y!pgO4#h1ff%=)~aD_82m#>`_ zyABw%LZ*nj$ZBT7OF^$O8SU#i9iHyJ6AUVqgK;xr7LHR+XS6^8R*`+Z73V#`Q7y{D z@$U=DBNoL&A4z@o4VbC|tAXBQ0JgK_w7U2KXybUC`%9TNm@1C-_@r^^Gn9ukaDvw? zi*qS_rr|b8Sk!&KArcK zzpiAp;NUxli;4t)^|2J`laGhGJDmL1e@b3eouNBtATVl|;isE5&Nn+$lv26IwRyYl zP8-&4cQXD%$Ih((WNoFa#^CeDY!z+OxptX`zKowNx)iOqeXViSIrld42|9H(?fz*) z!7^P^{FHM~$=rZb7B2GkgCSgH-arxm@A0S8Uj^!~7lMg`HGh63!^6lgWN!?KhfI== zDEmkT3n12k@YXy!_r7s0CAtfokLmOH3~Rkm&|EA~({aJ&u}~{j)6*Zw?*X~|NlJoJ zMXON&bDkV9Zde6p!CGY6MBbStrhLctW{q*c_3iyE9bt*8MC@4F{=y4aw117Z)#;{H zdunmKF8y_@Q!_dr%zs-Mb+6AM%AROwvXNF#t7}ig1A_QqO2U1$XFpWtSInWNqtj)x z;n6WIwrVYt;ke=<;28O!I`h2?Y8)0~&ds9BcR$!ymp_?rgb!G{Ss9b|H5jA?H2sEFHxSCX{N3+h&*-~ zH47CSNIYw%>GXg%RsKJ);No2wo>Ph@k)Z*qOp13mVxR>QmFo~5@SZk+L44d`v8Bu@ zKn^Ae)NJP+XuPt8o+4D+4@m*8;4%Om_z!;2>=F4EL}tv4t+V*VdS~?P zkkj)UJ(xQXVt6ixvW^+&+=rBN0kM7V%oX=JbRNDpWQf{%eO>Mw_{i>6Vf-x*Dkn=@ zA2%=B*zwin>zTvzwx33?7ZT3>xMaLu*wj%P_)TEd65jS}z2O_} zwhfPt9Sd@Oby1bS3BkK&zRNNxL=gv0|Lnz--;^j5cV5Zf?XyjOGq%y=kAjl;%`f&w zACj7b7Cv3$)~76YrC!5VepTz{e5!t+%N$g91qU|&fTb<3z#sV%17_MPUMBvHaYM}~ zy{()J;gr#EKcmYAes&enQw`t4K#P_QFQXYs{wp?|1`A7cM}SZ~4UL)yzP8efkv`^L zSQE`w^^PuzqwJMd_>7qd$%cOR>pAw%LIKu1HjR20b(Ba&AH!hFX$Z2*=qP^J5c4~* zMAq#C1L;=$;&XWKo|7>G!m=#2pXAUosC<?kNW1f~Z~DcP?${}(nAvP zHK9H88~-KMjgv04&O0@~`BLPC^Ahq^jrKn(_fr~ol@@r^Rld51xUGj+AdC8O3!p0m zE+RmJsj3yCwEasyC1iC-fX15_dW$JNF!kpYJBt+;WKloOro0eDS&Vb)0A2D6Y1H@R zV;)BJgB-hpnex?2yYP~ZLsCuec8}|oC+5qVL#^1myMcH_)7MlpsHAoq-NG>^VV#;` zRl*RunvbCyYC%4puf)wO!SmoGq@;`xjOW@kvgxx?3P6oh2MaZX0a#;`JV}&e~~X5pW;!2GZ01H6e@NjUW8yH@~|;kN1c`U7_G)@R=xfhUI=Nt z^U$skKw-rvGPP@ah-p1#Swg}*oVJ)xjD^Zk-42npX2XPC3S1KNe}0`5E?;3NOPK=F z#>=vp_o;0MSMg6NgcYmC3Fx41``oN{_Mk&{RvYsvfFh#BC4}&nvSl-bT{FPE=CbrW zLbMf#BWY1<(fPG@O?$=Oyt2&pvSm5~D-Rz#qv5o)`9iJiu|Bcc(Q8`6Rp2r(r+GWn z@^-=sX_8ZO+Aqg=mtb5g#$DjpAQHv!=;vPP!f+ z_pWOh(z?@u6N5Usuy(Ke*@dvilaH0n$9Y-Mu>=qVt3CMmm~kxAibZz--HkXmKF%o* zu$#rXKzKJ8ztvIgD#SISjlp>;d!hQyBFzdW{x5q#0O7b2TrdPU538aL-*T*)e;YNi zr?qpAlbY2S!Xmz!+pclV!U<6$v|a0I!tyZmX#T_9yM_WVnh2Zbn1~YHh!5=1*fOf` zo3gjs8dKvYZza~gXy$V+zS{>2wBdPOgrI>x;%=FIg1ZJ7Q}pg|pI^Z4;-RI-seUrM{3tQ8A@ z@6er5iN*yw$&hNp3>OaJ?Z(XP|Hv@>iCpOC3Q z9TCJZ%Y0-wzpS6#S{Q)?K+1XE3JGDs9JxHlSSA++{~YM;z82><+#bQiF~SbpF|_Ml zwLY>q5nH{n3NIY?S$Zl)s>6%22yvARKg?yGBG_lP;ra6e3k{lJfu`1in~&17Fy0Hn z1<(GEiT(c$V%RXwt4HnDjv<;FU1z_gve6?AgRnJB>BIj{KF#S@`Q~Yy6zZff4SEiH zEy72j52%^8jcz~;^lh{)Rcz608!n3;uS5UPqG!&8n95H^sDJg=3K8|@qxgkPHFaF0 zbj2-@YMLW7%IlhZEfryE!{MW`7WLVsj|YS+eyY+0m8eduA2;RPUa+Y#M5ev-YQ^zc z4LaDpUgSRaXo^^9D&LkTNfJEyrIR?xd6EPV=SORsn2{%!xqsXKNMtwrw9Mhl&-_EhH;5D z9Q17HGjfzO`)fo=wxbE_qrm!Dupc2D+2opL(ZtTGH{uA9Pk|^huw7OrXU~C%e_2eP zb_%H2E)XTbWeSMs`C7YLU$)Frwn0Y_q2jU?U&t>Xj|SzIpO5@p=8~K*mzY-kqxCs> z&^aVHO;3=QBaijGO<+=?k*}IQY!h ziU}$-PnOMudzslyb_#Kcg!TA|q2g8TCOj)oJzpzmiVN?eRG##Tuil^n1r{ScY>0>`W>HV#$Rqvj=_K5^0 z5F?m>$1#~VWh2|Yk zLd(_k7lROuz^10#PVgIV_m-s035L~Z5~IJeI#922Rd|w7?O6%1x(@$0NA{s{%y*d1 zG^v6ucn+*tGm0g60v_mVZQHh}sngL}a0FY!v0#JSQ`uQ+Pn6iTsgtHv832xS0jp^` zQ0R&eh7PgYu@s4lHLD7eVFMAKbQZYs0hd1hj90RX+tckkt`+u_vq=P$OwNY1UeyGt zp<+9;e5qya3g-P4Ok#gvS++G;3YCB9U!DV3Y=^;luLv&Z(*FiP(bJvjS!GKsTJCf0 zyj293)r56V%g{o-h_~AR_Bw&xoemcWNhRo8ad4NBm=0;<&8fE*tN*;WX6g?O_vh3X zVVv{_bM3yexW!LP&{HuTCZ^fAp>FqM#;Qv^9Q&y~%H4Voq_xoQX9jBa0niFarzkQ;d za6fH5mVc4_BWd{ZykS|DyYCOvOM-4Ie=nF{@dF%12 zyAMb|p3_z?3OKCmSU~p;1$9xYiQ7cW+VyRcTGqu+2R9j7#9PMOlMnhORKJq5bxWsKmoHuEj$-dgIqSY)2i=QVpEAm+cs_b{J;uUWGO~S4ccug7VH-((je0dLZPSNWD=r*7yVjt+<3^m@mL=TPn*XB7DFD7M zzqx47v%w3G{H2$XD_uU&SWxk$%m*(Irj8reoJ3R)-^U;HTX~svrRs&GW7QH~Jti-|nEvgFecnmij0Gan@S*r=^Z?ETaXvezM?OUm8V-I;WD-RRN^xM*~;AC>_1gP!Y~ z9}IbEF@3{cdzq7iUiF``64!b`F>!IstYyeGX}E>Zu~pAF=$fNLqpfQ{Kj^2)nmYrS ziyFP;psxmb2zfIWJPhLM1n+d;{x0^u)BC<_T5yZuQZM2BeZ&3F8};n2 zq4+Hw@*Kt z7t-9>=drPbljM;KGj#A(a&S3{SGEwi>qFhNG+~bgho&=?V*A($MjX@M3dhsph8g{| z&7YNO*GrA{lxWR>9Q3{j7l9GS8(QzMr%NA>d~;G4M~`To&iLNns(5!N;KA?qqnmCy z#@!9N_J^=Wu%G&6g`YXJ&62+QS8QSgb&6Ws z9!GmW1T39EZBt>t+bV3u5xB*!Ag|x~D?J|P(?U!}C$z;zkc&)GBidKfO*nwd!ukr8 ztLdZC%GDSgyk7x$t;0fMWez>5v7b4LzK3g=JhV0s?=@!xdF)AQ@wk{hGG{v<+ia#5 zAa_0PEA2yGx1bd|IM)HTFvw;A_1ojFls_Z}4Np?n$d0*u9!Ej#$rbm<`n}d<5QBB+ zsc4(dt+5XkDBQ~UJLQLz;jcogwIy@LA{Y3s#(du*t z1{S<$Loxg#F8vCvFo%;Mo%V*Uwm$^hIJFtmtuEJlz=#r^`(ImH362wZbhoL|x1*KT zIIgunFI9+z-yP%0%|qqs-#-pn7k%l*H3>bo##?w z-k&Xqb5Ys|QG7EsUe@PIrS|c7^mUN|Q>m4=G2t#Fz9dlTS0K2EzPz~y|BX=4I7 zpxjYI?=@(tPzz{$jXmEh)8-tV^xG%(;p}_?3P)Ao?r}Wx28Pyeqwmf;%o-Jd!FgutR=(f2*V(VX zOgu#0V8Kxp0-``~UbmmU(sjgUAJ(y|%{di_^%+{Q%VS3l^xwuo=V!)}d-ECl*;&^( z!N!s~QR{lUe}5IZFU=^JKx>u+Nl`tj{RLTZdDZ%X0neS#SHWU?_kv{38$-Fcx6n|D zXqP3I>q4d7r-J4Wq}j=JVa$u(a<(W@Xe_H-&7-7VIJ`{MM+|I?Cps?0LpUQk!7*){ zKhsCqWx}UzGJ)BiYP#O&5L-fqC7(2xKMK&I1~*Wcr*R&(Ra`r_alsxZ^{Sj??GOn| z+%O(TRoDn}Vk;DwC-%Lh03|&pXq6D}bpzJ%dJs5CiAQTM@azG(4uaK=_(E%G&Ner<1es!u$ui@!j7=w+bqR>P|9R&gO0vf@jj? z94T5}7R0jgKH)js@qTVg z3!(=2m7hpBKamKQJpK8e#dFx~iB6}541YjX7J&6+4XY*y0cKI%Jtl*xBC3a7cq({Q z<+1~@W10UsuCixkHef_F0P8Sjranfy-Njt^_WnV$FKW}hR{05tego>4t|AAPnK4{( z*+6D>BdxPa=YdM*w6;>Z;-cdtueM6Bi8;}YS^z&SGfIsQgz) zFTLxbRJk-@nf$x1Y>{iX>MZr3@xNFXYHF|w-SWaR6=v4wx~>_O@edQN!R(3LW%cfN zDN#9hsIH7+QTWfbLFGDNzRX(YQi>k%X6^*`S_g2e7wCPz>~`Ca*FGGKK0W0PjrQ_u zZgx~7FH&!^7WcA7hY&HZ{q8NEHf#<4i_*BLZhUxFoSCznhYD3NwRh3P2(LM{1AE8v zXvA6*n!m*Rbq%L+;DIb2$^+bTxlfpOOoJ$6E0)>I>gh0F_@JD~Q$o2aXGldiDd`UW&Lwy1EIPBKajvAUDk4NrV0AV2H(Y)HNBH+@8j+OtPo2;le@kw3LB zOdt1Ehx)Il{p=a>0w~T8QN6_Wq3K+|fxf`^yom9n1ry@=c|^WGj^FDelq~gI(;rw^ zCA1Cs+RT1dr6=k+PmHq=5o`IdXANpQ=N>mbJs}X;4`5rVY#O?^?I@b_DDC3|xOJZ^ z_6A}d14IslK6?Ora?Je4gvdjH^%zxomf%o#l(I$Um51|};e11}bhhf%qNW4cFbdf^ zjOtnAWXyfD3r1A4^*DIx2~mP%w)*CkTeb3?zqk%YfB+7of`H2 zfUc)n29DE1KG{aHuTe4HE13o1X?13Dm z8MKFLMDCE;rAFmaszPd2o}oB+7UTXC)&q8kCbuw6BiL6ecjE)kTose4atCtwo%>xE zxw_Oi{L&-h2DG^|u}ivDZrvg$=XSTAxk-;j2hq*g=OtzsGK0pN)CrMWFVJ{C)w``) zOknl?*I@*+2I)K&vU9X|^sXp+v1!L@)1A8szqQ^H;ghcbciU5w6Z#gcit~xbnR`-i zb+2CNs}A+!m=K+uL z5nnrO3lmEhs(+rqdFA4KXOs?BWGx>{X~%dlMP%tw%^j7;fXdE5QPoBgR`-i(^t7$@vj^JV>;8V>&&y;8nY?U~gd zp~pwQ?(?&msp`H9dUfY}^#9PuvSd#OSnWfx8*JugSj1pg!JjdVVc(R8BF|D4YkaUN z8oR+Laox($jf-Hdh0aTJz!+%8EH*4g)z_jN@w`y<%;Uj@O)N*1eY<||AhPDM>x6&W zW9MO)Mv@0cROu>#nWWaSqZCMF2binhn#iwAt~KVBu2@dDnbF8cUD79><8`9n=NJHo zu_FNv&GyK{w%M8KHq4n9Cc0h@wKznDQ>MiIZG zSXI86nVU=>lhDVM=cr|Fc_Mh$#4HlAc-}5jHE&ksE*1OCVq6>@I$7>MuiXCftY+4@ zG1lBOoNjV%$6ii$PZ4}MUV`=cS%jXKG>Z~$$fLHKS?$364yym$!QK@WVGRwvr+@1& z9KHT@Y5=0%{N~v4^Q(K~)i;jk+?-H(cUSioE#~DlgwUwwb9YTkMh#*C*HJvsahJ-y zH=-LIc%g91eNOvU@{{8OqYBCeggdL_3-|v&iq15yiL33ylgTm@60$&m00Ab1 zRkpC%fGm>)0Rc7ch+@Mc;@Sv^6j5iA2mt{DqD87UfD10wh)dC0n*d^PjY=)Gwhf4i zmfF->MN4_!dB5_i>{{d$U%yqw#mW;U6g_}8+SDZrpVcYERvEBw; z9ePmawb(x(MvXkt{aP|xa0KOGvUNMu96R3I(CgDf{$HL(G>!*Z`=0(Hk7I&DJMoH9 z6E@XAre?E*6u2QR>^3;_v@BF&4&#%f63ou^>d>;@KgwjR9-yWns$$YiD3(8By$v?+ z(BT{2{P`h-!de0nmzAL60iXB1x9WOl>d6v4;h-K2)g66tu`lk4Jh8Q}bj0Lzq}PuL z@)mH~;PuYRCJvH4G+nKC8Z&7e-8q1iW0F>lVF-G6kA&oW8}b@9JMqj;WuQ-Mt7EQS z6@|~gJtweLHUkMPNh0x{y!IG8_KHICngSTZMF(D*ym=%p*2^y5SIZ?acq9QcLT$A5Qq-=nX;^{ec4(4{pA_fz0tqo9Vh$005^MV=)7-~ zy6Q2w3Wdus=>{pUrN9G|x+YNg7s_2RfXT~nEktRjO~1UwC{=({&m9MS9v^tj8d}M2 zk#cN6QqDs+aS(f0#F_*m|GK#Bw25WOaRt=jaKrs39Y>p)1WFMA47?!qDwB#<$`P$> zN0$;&4)Ui6?QW`0)e~=_49wui(*J2balf7F4C$A+jH(Z+83qZ z%;CN;EqSJukp1>b9dAH2EkJ-AG;RrV~uG_@um$HILoFHlD4H9>9SMt-%l^Nlm zilxjRQ$d`BF$Vb9q=E}3`ScBpHH13$V85ER>&`nfTS#!Z)G=!#ae#4G0)%SEK?%?n z1FW|znK53RdMP{9%qlbi{w&W;Dwran$*h>wXV?7RF!UcuJ$k_1 zY7SH3p2*Bo!g9|iggqw1Q_@@bkWv)=v-ed~L(!6W^WxwSXWT3B?pCuRd-pdh@0|48 z9=lfRT5o1Pk#?VX=T$)pKGvUsk6t~Xvcb%HI|Qq+VIwKMrPD)ecGb-g z6`Ng^xsNTT>B^bA_5mc8-L$Ig`pIjix>1a;eD}=ToUr4cm-8MhiB%J475%LjSpkcN zYE8!{NIJWTFtmFGncX(ji8lPW_-^lKAUNacjQ~rpcZxbxOyX9_1X_8-gjtYHY;tNR ziA==*9T}4a!+oWPZV>w0aO|>t*|N%K#OSEZM@-sSE)%Ojq(>cq^~C~v5EDTjaL@{; z{u|t{Bs!{(&&MY?=@etQ_%ewMx zvRd$2utKCbU$}YnZC&$Mv67EjQAg)2nz68CQnIjp!6{g&UqGQdwPy7tx}mDjj=Z7Ww!k|_NQ0lUh`$q%I0N1$b# z1>H3d8og8chlC6JYWU) zjJ(;?$&7u?mo zjx|Nz&{CSCJr|U6Bfg)Cx*h*r%fnlr{5BPJclOX!^quJMCl=0|Gd}eoGv)cgrp(!w z7mFOj&mY{}5aST(*D&`dw6oqzW=FN$)bxUes7L9sj$)HhRvq@*Myd&WZYRu_N|SPz z3u+fKyvxSTyM2~c}Ik+ol9bdRhvk)xi{>-v^BAU{CZ)duDwCjDCyhdT`7rN zS>HcyhDGnhqpX%GTPJsZ4*luwS@o7@Er}H0-aX4OxPUY5_8JM)&$_#IcI!MN$zMOe z{{;2R-1x|&`*PL|9{nOOr+xPR`7xxUUh{JP8rwfl4HN3pEV|zp^Ahr8`Xz9ahUfR% zC~7p;aXM&{lyjY|_*XE&lhwOYUZS*-r~9Pw^pySCc3{oE71e7C`HD?HQCpf!!WV5D z`)ddH=B4COcnb=-t6dkQY%V0yyu{aAl0Fjl%B*?8ROh(|t%YnUaKEp1t~y)Bj*M9# zV+D$r-#nEiUap4Y%5+k?w0=?=7BFe6+9R00R6tMqnC)L{9ncs2A{rpr*xXX_2BN2| z=xB_)YO8OQ=(2sQg>Sk-8*er%!*l|t7xmLnQnr3$_s{!PUhwb@U$92l0*^xRK@a2Ht3&QiB8_gCa+kM>05yhjcbf@};NkEe zALX6|?}88&_E-#e?M9)<$<+j_BV%h#L=3QrAusmk_S$PgVJUr8kHFhv)g@ZYaT~>Q z_jf8Nx~Cl~68E}K6UEObW-dCR_qfTaj+z9RMVvmL+sf+H(|OG0*>ZNTjTY8pW~`ee z^Y5ax@Pj1ga^wO3#V)(6B(N4ctX zDTi$->9VhGR9kOwnX2Z#HQ zW-QOqcsOjXkr{-JwZl9Q6G~h7xryv#H@RNW)50&TV=T3R%yJ{y+eGNq#tEL$Ku}XC z90JG~DunEa0kkB$bPI-Qg%xI|82ORWasr>+I3Hdu=@mTZv_)M2=5I`E^|0n%> zgHpw^1xz9@nD?Bpm=crzxw=taBajwG;nvO5X7Q>fDM9l#k~w}OEhw11TMHLMMynQK=BW$}n`J;ZpwcIAbA zexd`r=DidzS};L!IFY_OcdnyWEJ*?9M0|y5=YkE ze2&6fjAWk{8|C6#{N&k@`3S8RpBY(JNF@3YcBZHvoLRM4&NnG& zt}l`FTq5@VCPE7vPV0F;qfhYKIw!HEnY*?}E&}J=jK^yiY{)p`M{=o6p68_4yCJ3B zFSs`S*opZy31zJ;$AI`q;$Ql)Aoxf0MA%@7<1VE(^m1A7qKz{gmnr*PiF_w^m6A~B z$^=_cY6ymsA5Y)L*h_75-xn&Oh}I(ZQsA@;p9J^(d)xi+Ph@XxCb|DU#<{JhMvntO zx{v3w?kcKZJ;|dJzVS=XjkM^u2t&g$LrRk_PNIfa7prsP+iO-%=CMn&+nhJ;Zdgeq zcx+Z_MblW6;%$fDGt)q9`DQ{u@znM0WPjY4gF_zt)VW)KJv_#>{w#fqo0>yi2d$|| zA_|CSIVlJzt(!+&Qc`SkuU-?y|KPjxnI?_XH6yV(kyyA*vy+2ud?d%rW%la!8c_GNq0X2Uo#nR)( zR|;(Q^BzMZjISdZ#dS)p$S8&PQEQfPd<4Gv z6#?@jHNv5Gcy4wtyUHMs0wUovP^6eaIaGSoHjbAXwR zuFRQ$W^M?v$>Dm8VFBO^^(%V}Q~OWfNWun{%9?T$G(CpaAIx&>p2*1ZV?3B$ej6BE zrD8k+;O>n~LR)Fea-71T*CvN62%}e-ajY=C+VIO12EVqkwSEs@n-;u2%}u9sA05$; z9?@I-=l*&m>!#lWJ0)wZf9_QAJi*I7WBqzW>U7t7Z$^V6<3#58>4mH2EF3aq2KhT$ z(M1#1p-aG_(T*f7UiPbV!nTQ!^v&sN%*u|{EmRBg_-~wtnc0i`2Tr@#ZD8n-Q zZzUF~GJIs~(PTo-#6~)Q#QCkwSxBWjU2x_;${(!l((}sN%mY2MoYr&4Z6>xy-ivW1y?+<`8pa|>jjl`3>{hh==oD)9$$tB^h)_FI|`e_qlue6jPu1f zZ?Vut(eC_Ulk-QwftdbR3ps+_E?GiXnwb}G{zHt@4IRejUguQld9j=h=F!h|u;b0_ z7NpRr_uN}zHde{0MrSHtFpEQxPf>WnpQS}u?+g)BN)|YBM{Qz{&!sCg)w~I4ogu1h z8VKJ>NcxnlQG}%m-^|y2j@_84L#~F3;Hq2jgQw-H&)^5gS&g?)ovF9X#tiMFt^?Yl zYQVD?t-?m-ZH7B65ta3j*2K2-z5fWZ-q=_jGaxI%EG%VCs-bBmU27r+w~0Yj=3Nw4 zBg`a>bq%A|0S(0nZ45{nQ*OCh6M7xEaTjAaHLn;c)89D6o6682S-61E@%7L0IACtWZxO$hUiGdJ6hGjUY%6yT03 zS+3?(-Pps|s2&t@z1!)(P!zT!dU$o#l$OgU#5KA*55l8vU ziL8v^2R89T2jP-7`@-Vdp-ZS^hLV!{jGSOid= z(v$HdK_SkTk$BZ7x#J7|Op*SnO%3?}0JlbVyp&aSpHQ`|RdzDb5Xq|73~*%YC;;zQ zZfP`d3C!JYsHg*`WB1xvuMqY@!vHaC;{CLz+D5ca*a`Ta9@eJk8N42npBZOADRQk_ z&)bjjrW}k5rJsq37RMW2zSOn2WRrks(~vx?)jOL_qO4knd|Pj0g%>Ecae;ZvzgayD_$cpZ{i630cB+)YnI5N3LyI4?ua83$iA&1JwR_|Ae2;phAO!6FqK(hi_SzC7+|F?yMbgv6$JnNrD_{U?uKG z8*!=VcFD%s8Tl@>5odcfYZ6^LSP(6E1C5D3zXn`*i?U1JLzNe4>k#H!;5oL5Wdp>2 zV8auJe-t8QuYrLI=^jG2F%%v*aLYL~bpE$>rWaEN{*RGv*$y7*vVi82fQ|uYgTenR zjAz{!^bsLkk-P*c-%i^?Y3Jn&`B{0cGo_p-XXsDv{c+I1jxmpI$m294s&`7xTZB`8 zlAT35eiv<5ZXt%0IpD}NQ|u&qsbrO0emjnYc^!`ytfAKf75me^)tR*8?8%CMZtD!o zn$!{ZYGUgOSs5bQ>7&GxgHL}3&s)Pj8OUneO-&NL9R*I>Y~uCEnlOi`YTKqX5oiDJ z@VGh>v$1q2WR)?MO7&usWCPl;7Q?5NY=q3YjV`a=NY&Y(Vw71&ykeD!{he^8jk>g^ ztptc_E)Bs<@U)WkphR;$M%e>U$`Mwffx6Xn^DdHVOJ{(!PpZ_+BNK4XX-BZ#u>wt6 zL`@@P(tF2rr1tnp9TDNXNQ* zoKr@K_wL$Pv7E9zr`x3PN|INjl%tl;|9};G{b)2jj+37 zbU`zH-96Tr>F*uq;Hx&)xQRhDNUuDD)xW!vomj?O@Q{Wzy%4J!zQlU<;C$KffGhin z{f~ULi>dx$$$Btl`VP?_c21EBQrd~N%G+`C>0<{6z4`1sAwNdwtZR2&+|FC<;4-h> zwLFTimIlX+@R*nA@34J@>*nTUnd}xkY2dLQb)A{@h4W4B=_fDdy>M=`9eTZuGx7C{ zcr$EkWnqo1dfQvwt@p(fB|DJMoA2=s?1GCGG1v>&J*v0{19b|8j{JN0215Jjw4g~3 z%?ph!7Qy;^mrL%Q-|?I=07VCOQtOcIXAHCrhQq_lp_>uVN<++H9@K0oxo3l0mcsc4 zc$0y!3(%O_h8t<@7A33EHrR+l^$Va4=rm^sHdR#)O&CJ8hM4gZCKkr^SpYr9m{$>o z$uLlfE|sFWN&mvF#j}1G+WDQ5ZOn4C*p7}^h(qC^q03R(&yFc%l!w;cHvJ;cBf7;$ z-G%yT6rM&K$57fUnE7~+@y@C7CL+@U%z^KiJ50&CLY6z}{N*F`@BC!$f83vlZR_plq8%v@6fn^z89_gqv92lAn zjoO$|*iQ#)!~b;qkL$uVM^c0IAh?7$#7I(S&{A8E0jHMU=o5F-(F@N^1ct?X_4wrJ=pQ2Vyw|wnaXrw*=NyygTwKCl<^Gw|@AQwRYd=4O1pGr#W?@er=_x|yyjQ#wd{Vgg05Aue3DDR&ztr> z)lEBzlI!6QLciMdrp-I1x(fW}5#wtMuh}9~2`hJsuD@w-A>Aucsve%|=xpR5sU>4w z^yAff-|$ye!Oc%0!QN!ukY&LbgB`g>&yVX6(RX{q=lAP6$FkUhcpLNnNJsS#o{?2L z_l_Hb+_>_0L4pg~dr0=myMS-WvAvzM4v!07j)FJ#M1)k@Ew+O6J+A&3=YVsc#XY*T zv8B~4y32}?Up+YTpfBxUM8VzR{hiluIe6z-o~(V?*U%Qdar#12fAO8LnJaz&`|D}3 z-a9q_;QuZ!qFD|_+`ZGaJUZ~!9^uB4-5+GB=^n!pZR;QyaJwxc|MimtrN^7vQm5bj z>6bC~oA1_8f}f8CW&6MjF3^2u0fQB8)Y~rz{Q;MHh z;;SfaeF7lVveMwlA$PZiyaBgQ+eWAy$+)ne8&YfpS&2+)ze5SrEJsSm&uKY%K(|(q zQq`pK#_GASS2!igf~agxw`O$P&5&g5!h_5ZtiG>S^kYl!VA%Bw3-5+jRrxXlzmjzC zM$YsvbP#=44AzBch)~9WZP+N}b=ZEc*$Iairg7xV=>_+>D5YKQid))K1ZzGgDqBQR zDb2BfM@o^0RNW_Y@h45rf-OOt)A;3Y|C%i*oBS)SATyb=r4!r~ZsOIW*IDrsCiVRIp!u$kZKjzLdt4 zrN603wqejLLJe_6L5roxXiC`4@qvNrO$6i=+V8kp6xG5qQdR;-CYW%piEC+R_+0Jt z$SQ+UyQGdG@%lzzTu0#sr8FZ}^@d5<^Rk^QJgM;G91(e5vJ)ViFHR$B7A|V8^G%K~ zhY%e7LD@3b>cqFjD~hf8q(`f zW_J}4g|o;>oQ0XlvA0bOz7^*;TEB}Vm}Rt!JYkZ+X`v0w3(;i$l-M(22=vnEt_M)s zDLzlyof|4P1bj1lhhZ_1 z3sZY_C@;GTm>sA#s#^$hu@sOAh=K0t(_>UwPw$rEgGJhH;s|&~Y@7RC4kV(O;ZzgJ zOD(1e?ftM^M@HT%9^`rFvA0Rc!gpi(#M{!NHy1uk{Kt9^0fe;p^pyf*EHUejz^ec{ zqYi-1JVT*my_^#d?K}U+5WJm2|3zzRawP!FiZ1dm*U*}{E(1K-d-CTx!zXNFE;VHm zXVmE_NmFl+hZf&+(5h%Yh3yb-m9t;jb|EU<@l&IedCdg=T=j!WYVUJuvr!VWM(Hj5 z6eJ;%HeawLdzJEWwoSamX50rCEtf*mbwI$Mzg&l6R8ps_hFv~JU}I#3*#4#EQpe>InR~BQBZh^HXrnrlK632**GeLOtXK1ICoSr9(60C@tFFkB5^#)@ zBFv{W`t|2~5!1X-nW&sn4n&$$3W&6 znxP`Il(@`1a~KEaMR)CHmvg?1LPp@9ewMSZ5i#`_q|pigb+SV>6d_TpeN^`4&pH!% z5B9S7-xaHZYNmgL6NgCb?KX{wm~L0VGB$hdE|FEp+N3XVId3DV=*^6@R**MrqQ#A= z9Oqpabv=dU26Fs$ikCR6N}+acwn71XoV|^sMLr@r->m9KEgs~rZi;Ei^+K2^$ocqk zp;Sv=)Yj}n9BsZfqAL&1p0IRMFH;LZ1FSotWu^=#-uqbp+vZ@r zxSu_sr}PJyRb>j$ON)`*>aSDF5Vc5S0=?{Ej8uE9+iRz49{@M^Hy8x{7}>MuzF#c} z3fsi=IZr^h=6d=rsHx z_$&(0L{8fFT>9NS`(#M97rzGD2B`CtOKVrCDmnMoruwa9FpfmW&t&9hUVwIul?Tf|CAL?k6v zPo3NGai^91 z=pWED@JX7OoNI@|5Xz_ajT#d@(@se3pl>k8ZZRXz2&uUam+cPUS&1VgW1E;!sbU;L zSlum*Gn~vy<=QdC;gRa|x&ivhUg%A#!XU0rF+nA0aGN4FQ@Os1Ls3hiMuXF(C1Elf zb*_PuCSNHP!*izXV6Q znARUtPd(zS%G0TB&mfPP1}2Lk73x`ll<4830LT z2cEeXp{}JEIZp+xLnsnEw4ifgyfSzDnt(_Bi z3dnBaz~7;)F@)Jz6!Nu^UZwhs{G-Ep+yB>kM>&jxSms&4;WEa$sw}>Mz%AXB4keUE zh)f?WKlPCC>p7(&Fm9wAGC`|VJi<7;%d$!hWbb$Yi}x*3`$9!3sM1QQLd2bsC1>^V$41Cm{}lav9>P10$!snXMQqU1FZ)VqxAsyBesC^~hjfv&)M&{6 zdhs?fyyOa01W?onEgqx%BZAc(`_{+O#7$0f(8U`{X;Q=~djt$OQd9sGZ(Wl0rfk7` zGSNg05ISYr$#WD?D2MFHYg&{>N&K;VhInVGDO-@fECs8Ry`-Ns(4{9@q-M!>HFYelHBVXJW~{Y>>%7N^l~ad7vp;D5Em?j)V{Yr&v)p^ zPf>Xno}lHcHeSn_Ew)h;R6dnJMTybx9ErTFnd+k_Z6DIiGJKJwXx?g|{`3BX(~Q6w z22gDE&NFc(Rw}9jcIv6CIkA0v$q9C9GN&TfQ1Nvnl!1|#s=Tr4@GL1m)F|R{^l&1pL$w+yOW^ zK7q?osKN%10gTTU(Y7xj%S|9|g4S6b2=(;!9H`ttxuU!jXQQwpLrvnYKpK=pV^`4phcRV0S=IiW_%y8Bxzfz(7{n;eO%G(ah)W2Kzv zPjSi`*V3B{Hxhkn3Cde^9+c|IpBk1d6zxa_mY1MG8Ni+;k^c3IVgs7Cgdi?5V_sRa zQ;ID)9En4xZJPR{lnlf98A#LZMG@b`-}ulnzstoX&Ok|cLO*GO3a+pFw2#vFh^Q3q z&9PEF%b-FVoUbfHBGGG=luYFG+6R5JZC^&bH_b(wqV?rsgto;wVctTQ1Ov!7fj(Hp z@XgRzD_D}YE!3DBo#`K?bR=Zak@m_h3&@EEa*nM#(r%i2e&Hq*E;3QK8EN?_EzJ%Q zuIOJ@3)YCCO-6D$=C=xk-9Cn%DlK`xz=tcK-~J-V=K{?XP}qpjTU12S9yInhY96Q*}jLO;Uh)06$^5ne(93LSpgGqXQGI6yR`1e!*bZoeF* z=ApDyV_;6N|7s%`_YBIjn*yc&+xtO2N)C{&#(x6ErqD!)w%(TWiHobZu(Iq4=z)PT zHnN+I`d2wrbADS(Hfgnz@@s5a<^sxU>(N!+lpqz&p}yp-k#Lw> zvxIBZ1~(Y#UmNJ9cETx27?QpaO;aY!|3fU;2Eg%+j5(*|0vNZjanf^&SaNjPxwtuzRaf zvLrI-fBiai1!d?E`3Dd37RA<@f}aN`$dYAH4rfoYjXGLEOAO8O{o)g;5}Iq^&I73L ztNbF7yj?s>q76t8LkTKU{I{O*diT|1WOx0YnFjE)yJWTkOa^E~zF@vMJK6x|amced zG|v>uETxkS@lUZ)p?241jDbxfU#=wV3|914oU(ZATkkT;V}IIvImP4|@syzFY)}dl z)>vao(5js%&1ocf)*E_*qN>0IcmCx9Em-`Bc6IwF`&KI+tW(MY?(7U0{YLiI_mln} zq+=qLk`m}7^mvo9-bgFd-yJzWwn0_*e(|nc8pS&unt?ru) z5>R%52GreMWhFf53m4ZcOIJ!_^`v|bx#m8#_Wb7wV%g`$@xPB#LhUbONmuig5%~Z- z-V5pH+ysJQsfnESda3^w=xaO2fWmR6?ez%E{!3bAqO>m7;T+c#6D0|sU)4c*>m3@( zwXNTGGgRd^9p~&6N21l0&;Kcra!3iS#-K90wEV~LG89%RXfuR!G&cHWBXrOdba$&y zk&%{Wf=-yg!A0abdhpM}t>qktc0F|1NYPs9->bp7cKTT?=KZ$s)u_X53#3)8KZDXQ zSfPVrm|vcSq3|(P)|a-DV*vbOlALLCx$;e7;&s=!AC{zF5WVS_x+P`qd*b5LU20_FCE@h2f)$L{&= z#G-K3nPjXk)Jg_|^*8|A*bFH=v6S{?yh2Q%kRL(#xz zix%l5tEoq>QY5yUI~C9jjx+me)Dm*kvCWjNHgffcqSL>VSqg9!z?oIv0{I!4i>)k{+?Z3aGp*07iN(FR>w{`4g zadG`I52UGC2cneyp|u4?w>iEk15do7*VOw&Zv~^TENbO;r0QIyIYY$@3R3kR()`<9 zO;2w$2P8kgJs#QFSMz7f^7f7!?%vAhy1zz`r?I`!eBfg5!gK)=H2r#N(ki0u_6aWo z_MFd8XST-gn3Tw?>s_Ps%f>m~=jQUg&KAHgqvSe*r`~yWJSz8wgCQy3G`XOESA%!# zx(j@H$}X12sv+QjJh0Kt`*~~o?g%Y1$_N=%jKC4E?N4>lWBLN7tPG{>nRTT{v&W;( z8lsDToi4rRsx3C}5{}wMy2vs>*0{TxGg{+{>Ev{tBD=jKqSz>^kGng*zTQJUsb(cE z4eH#@FXBI~@mY*LW%y57JL}UuNdEd+;M%dhLQ13>4zIC4-R;n9L~A38uufe^zNv4I z%EC=;LBSE?e4roJ>`5Hs7doLItw(B;-VGP%rqm08d4Ywsz|LMq#ziMKfv6Nj!5$*eI;iiRi*j58IRC-oa{?@4%M!YF)3&Ocm|! z>eP8O8-Cm+nCubN``OZcb$*xTstGtiCu$V5ruXK>znLSDw%t z4Z`caP(&8zT`?)0u9;O~!fFh^UF@jlwL z(8u(x?jaAvJ-5qXq}rZee}B3{D2`b7$%HtC^VjX}yS{n#-Gjb+H*l}>|E+I|71YaN zQAN-)M$|j2i~;=;B;Oqg92sH#R?(VQ{ZqA4T2r;UUdZybiRrX>l~+g2M1@6GAK5L* zhodU2y?a6pmrj-QD@Z%UwmmLa?8JzOi4mg1-R#b^kfG-ep*r)>LM6^SJqAUaq;5CG z)8u$!nmhBOhEuGlKKWs!bTz+spivAjy-(T@kuIfbIpnj$&!FO4D03aENhGKyN|f5^ zHK)e41z1*zo88Kc=bZinsJX-hl3jJN-rSyx8xW?Xd6%0Jb<8WKbLxR_<65Nb_}v7C z253b*TGd)$HZ_%c)T6&0EHI#}?8Lp>CgkB=awO$`U=pTdSMs%o-*JLPZ*jK9M(&UZ z1Efbfh$(+9k}q{Q7%S!M%Np{oB(%UbC3DAPf*nOq3Ux%pJF@ht&aR0kWFK+J4?b^I z+T2PO_Z~I?h~Gv5dYLrA^Y!5b^`bslrwh`PRgAvjW_kfLcOu!-T(z)ot4+yenh{ zcgZsqyf6P~m9o<8q#rL4w(1roGgR5&ey#JXlK+IvTmtodqOAUAHy(c;j*@&|gYj&X@DWVlM`V1f@Qcy6-?~ zA^p=Ps?5&dw2amGpEhhs7SC%~qu&+tpMjC9m%3|JHL3l=P0u{~ovt$7ypCYB=ycNV zwFV98FJwNTi6ezed315lUxF0Zi7INSN1tvX4c@6=}P z55PWX1JDN$VbQyIg+mb(V$i%NyPfu%pm>XrJ}+H;TWzAazLc`oa9(;{`9X_Va5%%V zgrhcSuTCl9^N%!(zE=pmBu)M=FX%WLO@=RXV_q-q?1_gb?($D#ZH)8kUS`>)lFK(> z-qnC_cW|l&WnH<78K8i$Z)}W;f#wuTrS#OWB=kaJ=x}TDA(xuS^o;3(*#Ou1l@St# zwzIdDQ=F=-Aa}bgBvVWYyC5VJyf(kzu8y2E zQ0M)z;fKX7yTVVK9OWjd`)lK&yPE=gyI+#ru_bEnNdn%;MBIv;X#$?qF;z!Wd}rDf zM(KreD?WAYu`W%W%7c}5vX7R?`us&o8~pEg!4;LZ3dHL-FObiDq^Erk4>HmLw@X*w zK(F2BF-|r$P2V(mMChfCr{~H=tb&A}assw?Z_j%D+VE|ZZ3Z)8@}tM&t<}k0LO4mG z;rF7X7xzujrUSU!6^s-yWgvif+^t4j8cmUwiV zE8yhlx`rw~_%F-b6}*-LQ(hIdGlLb;SQw`9=BGg}dMPIg>2;klfw?_G&N?Mt@-=oN z48uKg#^MoOKTYwG^uB=DO56vP;k}L z058;seNJ~m3`TU3g=El|14=9;{~(QWt@p)VQ~JyU-7g$v4tD1 z*cZ}=i^W9I&mmGp&OQJCel3ZuE!yT8#Zkvjes>5Y*A}0dWvfH=8h1d$PbZ7J93zI+ zH>@*+lVEKt*+*q4+;(VU!Jde0)XRp~%mO2IjVDRu=wVQG0RQ2m<_W1GHJx%Vy?(ek zEJ`9evpEi_$cSon2+Jn>6bIyb?j0Q`h1t|XJ&A81jcp@|C7M|}|1;4VNe?+5*8cou z9sK~=l`wQp4u+c=!<^Met?H;r_3LqUm?}ctL-w*8QzV`q7L9K@$+;VgT}1Xas@WAJ zVHfE^7t65~aNxB6ro_qFxI5tlEXK*`xw{>tf2UD*l%xK!slkmn9H97H4cP@z^CWxt zFeP3=@+iR_Fh_?mc<`;dJg%|O`KwnRH$~%L(@wI4*jGM;*&&GOW>?4l&EgwD&v7iI zixk2+@CPABM}?1mX91@$UPZWlIcZj=r%M-bKSub0&vor1J6cGhVFwYaUf`sa0%{aE z7GT_$2vFv5{j)qNiwx9d6}vJH@TA3RZ_9;wU1WV8=(;^pnFGF@bWKu$HE|ac9_~$x z`(mkIE7nUY^|-X+v2N4nPJedqkI-v8+en|1zjzI*H?9lpZ0RmMiMKv8`2YxIc;2ypI9*4 z6#>&)G)d{8zY>o}eUhyglV_hqB;;ry@&C=ggdoUHuD9WT_vxH^fcpX;FAL^|X@(GW zVK2!?5$~Zt6Ww~oyBL3?PD!SK6C|>ufiRs~ytBQHf2n<2xdFeE94ASUL8LPuQZ9F6 z?b`FNCxI2Suc@cVGk-Q7&Ac*m_*B|lhpTq)ciTuYu%}BP+>hVw`=%vSJ*-i;9 z;7@&V5msWs51eEcbj_GAq_g2c!sZTLA?I^2!1f_ zW?I2}qXJ#8$Ix5YU61Q?#j&WbrH-ip&;^g6@Tvc7)s2 z4v(*jEuf!L=zfC4#BfM-R%A$SZpB&r*oInz@vp{&F8me~ui5VDWykhw@b4T+E=Iu9 zKw{`0-@9{X!!{CEd~d&!TS$HIDw!@Scb2wyrN!N52Ict4=)h<(R+IlU*&X zD;I$ZJ7bybs|X1h=}C`%(-9+Ys8-@`3Y=}j>8;qvIO*Kp%Px{7`zqwoU2TR1Kkt8t zdj)Aw&e!3KK*CAl^?<||-zeNomPmv?R&4s|idMUArB@L6;vgQ=iW92l2$-BQc_j`b z5od@>-gWYz`nm5{m+Yr$kMYY3H6A_ckXF*rRWPJ$j=PmjQiDkPHSr|AHu6x(<#Ay%Ntx4sO*94R20F*Ej{@_59AV~dM(H3QDY&PKcKVoNt z$KE^RiyZL>p*gxPjJ}@aJ}h)Ks_#8x!Mw1apPY5=Ko}~3Jqa-EpsNw^o?OXL;I7sb z-=4?TTX?8K6QP_Y2`65R@IbzW6#D~-=a5W0`6GKw7m;ZrmI`|z$^H`bhBdVg*-cHF zD1};J1%f8=2!7`*h1yLK=B=bOzs5Fw8R8*9Zc)g6xwC*;5@eh2(Sthxpa3I*OGqKM zYpy*c&sOZWUR}5X53ypk@|u-^HYZyWknUHVcRaO$C+&)njuWb8Fth!l2+>BZS7(?n z3QVzdzkQqNV4_W1W+vq>2j?r*X@jJ`W}eibOj7NU=#Q>m&X4}c_~*Zr*x6q-@YOCB zJTMAe-+tWR&e%-`cPuN7aPRUEW^(Esgp`zNIt zff5xxL41w~3HpuJk_QuRR;@eYR@tmuAMBPs<+^0*xSL)dS%LGo9lPgvJvr~T04QDsMH3FmRj2eL`20lC~axWOW*ll7ndLSzz6fplY7p6 z&hNO&1>su1`a#fjOw|*y@AJ#5sa3^?;#Iz9$nt6RXGV`zf3w207u@h=bEoVs2C*Uo zR8#Bb+;=pFTCPNipx?1O*gJKYf_>~4Zc5Z;r_Pp_7Rt%2-Lbhr{QK%TR_NXn^!-&Q zZ$+!Wv9D*y(>pXEU!gu4eDRei^z1+C=FYd(U#g=l)XZEhzjc`MY06;u;yysjc#K!w z$o^m7vizEk4UO~1o@oKsGhP)`xPh`ARJn=@f#dmq6{P$GKx#P!8GzslaOyDVp(ytb z2u|Zr9uJ;`G?6bp&Ee7#$UwOljFshrix@Z_T1P(TyA@L_KPM%ItD;VZO{{m4>&1(< z*tV-ZKYcaxN`=Rf2SpA3sU<$0v*JmWpMsV8%l+!1$6?u@{~Uij^r|X(%=58+)~boe zv5G^sZKMDPW})xdp`?o>$ISr#26X{jDMvZwsR!X`V82MS4yaykW;2XiKM}jY^|}w* zvq;`K$D*DmI_poU{20!(Wv1K5s*aCV#RKexZJmCYv+C~L*fA$x9khAFFP~rftM(Gq4l{AE{+=(}?kd-(jmn1MwTnsh_uBA|M$%(J2M?DCsGy|0F1 zX|-gE6%DW$U{(Lb*iA8bkxo&I{^(~W0bA_(vOcwmnIO= zFZ+kOKS!5lGnMHgtXn74J1f?-6o8(nDhAK;%BJRaR=+$vXO^D$grds&-GUYDk$-P_ z+A@ukTY^y*J~Kj%yaK1xr+9jPY}+OGDHd=Sh#)~>FYAawB>~j0g4B#FDzwvM?i4eG zzn8n+`3uf5x6ZTPa?qpE&NCadMt@^|%aZ@0GDqHt!B`Or(6KXZF4-3$vJLr0eK>9L zV-~dIsmd@O6FR7NAN$S6nzE=EBBSS+xD62-Z)jb2oV(v_@0zDD3rFs|#3YOafO)r< zHiyeJB^QQzmNdl7*PM=xY6qXoNEEV*;nn(Q`w#)vH@|BseMrYg2gjfKGvVgOV-X$&_D= z()zNyHdP#7b(9@@wI~&fQP+!HThA}|T$X-xy!6U>8;5}NFvy?;KEi#aV|@{ceCddM zFkX_ z9?gW>K6OB<;S*SHqHR)J?rlM`Rf33u%`GxJSBrtgoi`Ez3Bsq%-Tde@>mzw&xqq_2 zw-PihPVDnagyV{b)J_Xb5p2PSf9tEoeik`PRD?T0Fq9baT ztZud5;U)D+e&cPL8bV~jz97Vl7^l@vkDmBLXhGjl0?(VqHN1WA3)N2Rd-a2E2BR57 z*ECIJIIpiLy@RbUH_Es{V^;NNn`Y498c}pr2aA9F9~|bTw41HiB3hCxczGh?j>q2m zjB)o(1&q&BF&%*ji+a7-uKDek)vlL058sIk*pavUH?~`2Gq% zoBo4yC*N@bay{e5skPf!kek4ZlGvwUt2G&vT^voTc?$fYioW3Y=M#au<3y5O;DCf- zwN^;gZX44+PH~KqW68Wo3{OQR81u~?FR~EMp&eH6c5FcgJThlm1PSM0(nPMTPu_i1 zB#uaQmWhOO`C%OrHb2xfA>_`sJ9hrjd~~|n_Szju0>`Z#H+-}6roM~2nY6zx-cc_V z3D+lHyzM^o@)yRMztge5$0;P!%!0IctX-+!p0ldT=1C-3w%`8<%=zy=~zMQAx zpESFAjw<0b#~GtL(Y0%eiEEBo(51?*_X8^GeCw$c+lN28>Ljh65-*!1jOE3mK6Zdqg=13w0b3S!5S#Bq7NM>!3F1)$TZ}hw7GSKxY-YvrKQxLBTWBFn!A?Vq zQ&H1LD!InbfoRLU-17i7D^QC5a{=-kqi1_JhN3f4xs)q*f}F^3 zc&XXwTA+YqIxI*C+J$AAInfL%BG^D9*??;0Nf2H$2)b8FYTRzO6y{hH1UC&3Hc7w) zO$l^(jr%x}z0fRUs*Y7(H@#ZF@g;aOQwh2k0PvE`1Xqn2n#&|JhL^}h-iK!;K2q>x zO6fN`W$V{Q!e$L1b-reaW&Q%3+jiHr+)y)jdWg@w#)=dM$Tk|_<@R=cQ>;Wj^-4!Q znOR?pheS@1yoBs7QI-LUXzxkcLQ;YXhWl7gA;D9|z~Xyn~FiWLZJz}jHd+S zD$xUc&irO0cQ3#eex*db8o}9fX0n%xep713GJN6nDN@^Vv-(Vx7aA*4zb-UDKE*@4 zMAGQe4ybnbTcEj3M)!7{-HKAkrU2PwZ_!+*s$BHxKPv7P9pdausyA>&8`>Kb9CQ+)cU6VbYz67*m9Y zV0BoYRbBS2<-?tA((Nk=X!I>nqH=~g7(L>*0(tIm|uDz}EsONZZN2(`FIYyu+6+ul3Z zh_NR_q{~aKGOlI?GF0^di5~+|jhyAf?nTO{SfWum>@%gX`{R*pMxfNI?Ik}U9@(n| zRixRyk%0cO@8X1E&_-L3jq3san4a@;4{lgu1zj{c1UKkEoMZU3btUa1)=7CDl5YNt znIkD6cgdJ~luA7)`vHX3l16Noj;fiHhZM-w&IK4pk`2O&;#qc?df4rzWldZ!%SCD4 z6cKU>*(6f?#FuIaGlN|n2;~;g(e76K>QQ$ICl^~wh zJ^J_we(~~FY2S4{6fQ|ba#2+}(>?2{&?A1jDgJ<>+LKqpT}xMp^OT5p^nBjNHUhaP zQHN9weEyte%vz)+D6TTS9jKP zGUZiHc=g_XnV4%jp(*dD;ni(Y;f+^Xc?iWnfSSajjpJTTONthuww9GYREMX$x7U3P zu!~iw|NBu9f)vZD{=sj2%}nfMiJviyKD65 zouTkCJ^!dG5LM%BRcjz6G*J#OkJ=jmp+Rtbvklh@dXVZDju~Cd#@#(s z#a*L&l~J4N>a=gQr>j(NHl7PKirRFi4Ewl_Qg>R)9|WD=WQ!|Qv#rwLN#ks`?2D+n z)8^h&s~Uu@tCA*#Vj34r5DZPwww1P@T5A7_?pbhed5hLocZgf{@r7%|k31vS@$jAz z@@Q#FY0N#*fKkBY417VX!02M3ImN)|?n|!?t+P>5=f9?*0b8sX{O1qa-k18+Q3`hg zBjc1Wnc^D761SV(LsTV!r*G@It^+>|H&RT})UBOzZlc!D*cn|gP*4SOnS^PRlS@0b ztbsx+Xg7J9t4{i*mf&b9-c8ze5^$YlD1N}Tr&3+&EIjeLAO>=)R&UM@Y9)zecb3kfDR)>NXzMR*F0N>2H8fs1W>lk zD|mzU8y~M^UYM|AKyY*QjCg?NmP=jJEZo6HyFS2!Kz6JF%A)~P{FL8La5^?lAE@q; ztQIi>a+PrhgT37B>29h2%mP-RAZL{D@6-N6NPV2Nt>Bm(G#efDB% z@0jtE^>t!4sACqGBS4E$Dq?p9{?T7Ew$J`45Qdbip&`?gkl-=YL zzIdxEoypTml|sq|{Ybz|NqPmij7z^9qPNWqP~tx0#sT{cY-g+>xCx{#6fi{G_U$YBZc;{3A9`n+S$=`@a->)q?XCl7=_$lY z{C)tm*T@~BKt~I}Eha?ttWU@6Y@(iH2{=1~^6yjOR&e%|batO=_JG&% zg0YH%VvNl_{t~e71(Pk9-6oN6LTW#)bsr(@o)fK2Ii6=*jvI|G1JbiXMj`1J;J*gV z08cLgt&K*fdCP|DfA^fa%PeN!y)yE9m=Bi!Z+Y?4Mg?L|^% z!E$0ML+oq3XE#J08NhC=5sv`k_GE4wcGEVMmrcu?rM0&o3YvOsJ4m<}xVB>E<-dm> zh}OZF(ztz$+T~)OVqGCj zQ{dw{vz}5a=fFd3l}~T~U%f_Qy;Lj$#bT*RWO7OS)?J@3sFDhwk}hL}(0nA;mNE>` zYmb9#`$^su@SqP6GjPcnM>~ztDWLSr5HM{ED4Hg0dEW^-N&Yk+lPCUBWpXh|t)2HT zU9T4@i1ald|L2z1f5iBDB9r@oD)H5LwrkxSzjG;9x;0AOd&j0UMsDUB(^JZ|YO8xT z=#x_O``>rFQR6e;8csQ7vU7Otl(8av$&JI48+c=WgrGAcHD>*2^Z z_w4keu7Zvv#6;hVerH=uaDY)>328U>(Z4pc?L@N~sIO;<3h$FTETA&p624l?4up0u_I}aI-zJ zz*}`GW6c<8FOjx?cxMLtlM=GKc5OJ}&}77{Mi*aDye{(8^qRRFmN7S~lb`YY`UA&H z_Pr@VJ6|16J@65xO#G28^O*vrhk9#DE|iRCu+Q8*)i1@wzdSPRd(DyF+=jm2e3|{h z)>EN@@7#>^(30JNP&fhDh1B_7rprw9UV!osdV>cT$45GZ9d&*e9E29qwwV%k5Q4!^ zK0_0^TbFY(KW5kx46~EpTjzdV~s)sBy6fCqy! z1pi|5<;QgHF6yCw-4^S~c1z%+aTBlC{^gXNgY?_w?stC@N?$2Py4s&MJ83Y?HH*T6 z8>EX5O7Y1b-zP3I-|cSGV0UTS_x95GeeFyaw5h`AK~o#&CcCvuqg-u7>`i%J*Si!` zYn%6TQBc&F@sct+H-G$e^MH8F2hWx|>rs=|=vDv{kzLNXas6EaRRBl<;3Oyav{M(N zJpY&i{zu*@02DB6>qsNL{M%Cp&YM@~{VeS%IQD-g`lF4YFq^mQEagPzx|JA(0kA;3 z>dx0Rw+dkJhpcRZ6HRA3Zk&EZUM(QkP7 zA2<5GxnG6+PHh{Z#==yUnBlhavr=ZKJxn`GDpah+UJd3{4wav zFljGZ&7UF-eNC@(X)c=?d!fL62WF1{$vqs|P&!0>aFKioj{fq)3eFhqu}_NL1o&Dh zJVl5d`$Pc6Bs0(yT2xO%%sTe*32ys`e)a`6&09)08Bazn#M&O(k88JId$G5*U&)pB zJtv2GY3iR{MbTYXpMrMNG$Ptm`ZXcWY~=Qm2R;WxX5tSMAsoy6*UQ<&SOq1l zuNPQ18|!Kc3I(R1P+QdP9Zcyu+4b(bOZSp`_c@!C5zq3u#}+Njnxu7J-OiGARlGI1oW3J2 z{K5`%gzBPBEX^bEE}!M~wz8lO0{|s0joz&p2rzpm+0KgEVrt4rtgUm+yQSOblU>bm zT6pDV=a!`EI+1fgW~X)7$l|((WE)z(RB_xww-}m)$=s7((ZxK<$F%=UOzsxnQtEUP z-(}4ndLKnq*V)EKwpC1cbyPb#I}Z-#J+rl4(?0K|z2p9m{NJ2BB>eu?tM3ij=Xd>a z?@{Wii_3d7J7fLYkDg!uV(Z;cAMJNIAOGRe=)fcUnQD4rtF!c|xmB=m3@y+`jN;ws z!V!r~JL6f-xEh^RQ%3DVwri`+%0Ddy$Fg9f`XjrwS`hWRrKPIFhjA6^vPZAJ9X6W6 z(z<9#pYZ6il;MM4TJ#C%kZ@%fgP+24c3PyhA-aIWwW4@~H|kZ>`HUYsH8%9oBDc+5GyM32ojxdD zhv(IV--`A|d~TCI^*#r+>VpyY zOSOf&^mqX=C5Viq!N)v*g9XBzolgJiq?n=@2|-I8Odb`fe~9VR9Y+0fbRjmeOY*2{ zVh@maaL#a30^fP8_%b(gRC@+?%E~^SHRm9fS1U1Dy7teV$-Ys4_J#*k6IMt6u|KTF zl3OcKROBLVH*jz6HyuD%Xy|Q29vpqC$ygUt_T;Pm9vgsMo~OdvWgA&w1vTy!_%i`s zrB$ef*Ue1S=3o-V5uNMPBa)j2Ma`b)y$U4aCO*>@?R=x{8F?!w*zW znd=r~B6TcpAFB4=(0dYbx`DB3L~mnB*5?(NEV4f~pQhZ!3O&cM#2~HaXaa^exP?qn zoSUyJzUDe_$v$1}`5VylH0<7W>6eaC3Z}>;sLcWeD(J#nMPNX4o}f}$jZf+HIi}O1 zp0*zUX$3o8nTJ*Bl02sj`_`i}(M7BV|3q1rUheIpA2WqrDrM^C%?>9;)s;D3sH{=7 zVQKF{aic-+CjxDfbqP4WD%8C%1znn=_h7o>W2{D}&&+E7F+DrSU4xbn>O2ZFFU6eP zpPnzLusl86nL(eY;xhZ0c0CqrqL4TG204a|CYI1b_9V)OKMl%Zfx0rZp^$tpd@!x&+a#q*`B-g{vEvVOw-im|Y9+|DJStY~L3o zdMbq+RT;1esKpCxQA*O4a>Jr_(v73h<}UH$7aYG5nO%yZTYMC+^={SkBy~pHQ3d<9 zb_tgWG1)%=YQv|uFyMGqtJ4g#%F%`G<$VIoNTy&1!gPR4wqIST`O=}~1F!N)u zJ0s<;8MowG7D7)MoM(mkv zL0Z`?@jbK3x6un(V}0h{MW~nc7`ZYlWiUygxy8B2j^07>z^Cel=yROE3XJ?LTb5*! zMoKMARjmQ)^ZFi{Bg*7lAy$d9a}P3nMv+AV6`l-P9p0%eUE)w$W%D_d?G@n)A)REewS!iU%Tfw z3N7nBY)>~r@#Ds(6tkBn>vA2TZ4xBoW$uS!HTV>T&hch__TVh!Mc8L6vps|*^vVdo zTzJV;7t)0}gr50MP^w+Q`NWs4cp1T8Fr)hShNQ+npykEJk3yQ% zNyCBsc@MLMw$6St?xG>&qhX_Fdh)KHXiHkTL_KE>u+)bZpvk>pI28jg1g_35vUH1B z6V6^_U&(X+QH9)Lh2eGQ-H|g+a;fVE1sJ{{ zFd{quQfY%EgnAdo zMYTAjSE&3yk=IR2%#LM)YVE&<+8o1ahbh`&3eXTk4l}W!BSRb<1(uJF@ast`{lK9~ z;#(DAbDQ8FB>4ZN=#K*EFoC_oQTC94P4OFYHV`D@2UmL9?n zjB5A~&hi`e#~bOkLv+8EfwXD)lf;1)z`+nE_qUn1B@k6v*uN|8YiY%;K(yo@DjCk4 zW)IX@ka-yTYG3jsjSd6+r}w|eI>1>++;H9(m2E-Zjv%Aj;^r;rsKTy&)6XH>`Boj* z17I`h&+8VxGKtkm@aoAKZcr{Sl8N9d9(_8DkReE;g4tN1CRz|lA{E};3Ozb-zo^C{ z=%QDQ>Ump8{td!$mu0aX8D1qXaG}74Z0`19cYLDq{zgc)R`G%%*SQ%)R;>9|WSF8Uy{pCDm0%X$rbpdc80bkNb zZ6WLw!?uZ^@v4gWHwpV&p?t|}w1a%9)w833%nAjc;n!8-5d#Wd6lpgD@Fe>g?BKQv z0Im#%ETfQyM58R72Ny?<^Efx^(9@}QPk3vuKuUXS_|Nq)^A)gNeO!{V8@^&uI^ zE*$DqTra1k%u~eLj{>NF0luj<5fW{MGZjhWcaX+fB%Xj1&%o^p&I-%y#fv!wifsiJ zj(;*TumVxgyL(BNiG~93$V`%TT>-ApU$oBV1d-4J{Lg>t;h>~<OT<(>W55jT4MVqFkvRVkI`)ofc_ooh!a9Ddy?ExI9{V#RoYR+?CX@#o63}}LQ&NG^x8ED* z!72+^S>X6P&Y!}Uj{yRd!{4HQ=zC#+)-+g%133oR)`nlGMFSkfsNS_i*M137QH^ruAP}kQsl`zdsG)H{bQxN*5w zckZD@&`zT_6l-qmKnuvED0*iU&AURN3v55E=z5Sze=$DJu`oFo((Wd{USdfl7oV{Z za3Rhu4u>->yGsak)RGlv$znFF_W2x%g;k74l_*GM?Ux&fD>T5Gp6Rz;;KOd($~-e` ziu8W#-h=rT$QIaJtcWq+pZi}e*g0F4mm)Bn5xnZ@X5ZuAdv|vVut{w15>t#=Eu1C= zSRnRF&v0Sa+MZp|ZW*8kK{aEl6xi1XEKpzcV)qQMU`#x;)a(wVkvJ1h}lW<3vq@d%d`0Mx9HMG zH0fKycvk)48wBTVt70A5u#Vu5w|%-!Z(c#lwY1=tUa<#1AhEMON(jDshq|?%C-n%- z#CSN&^_jLk(PkW2+$#8Rlzcw?4m!EHe`cNU@`vo@vht2YmR$_=Vp76albrJI?0p6N zJ%!1eHSoX$zgWTdzJi-enac)hq>-HY%qc={Dg+Po44xE)hj^2R?Wf)dK5ga8-z(Pq z$*pSN{e@%T(ZJI-A(yp%PMKM3$1&PQlmzWPlTRlBSL*^rpezoCca_}2euftHu8EkV6q zkW6xQ6~SGi;Oh7{extKaEyVl+Q0=E2C5f`AhuL~v z-!VZmX;Vh0^X{%cJ#NlY{Ky<;|0RMWgyB{}n@huI$`bw#g+xtW8UmnS(ImeMd<_#@iO@%?88q?oVRB=jpVT!J#vra)c_GYMK+cC^s5%eyD%7_KvK@y9jf4U zrQe0U_&_ceFeG46-oKN)OI$DU2&Vw{Iy!sjZH1jTC;*Ihd4Qk*$bJRD9eVx-3pXo@ zm-#!t>nmv|*|-bm$~K`L3N%*FEw=3LQuw^vht9Zp!7M4Vg-FZo;@(!ZWM4>?N1}2Z z@jipf6`bMo7Y;5#9&*qwpz3ha&Mj)R^~VhrfW78G*7YM3s-s9G0slq^Oea%S%pB1#V2)ve@jSE2{oYn$}(L%vw7(P13_uByY~4`Rq<(#+g}19jzG4#(-6>S- zmyO#no{|eWiF2Ryxc*}`< zto7V=Hiw3p#pC6T4^w9!VqOe36|;=w?`;0SKS{^v!^|4cXZMjO*@#FKwqG&w-&o29 z=+LVfJ^!G-=j`@3s;}%$|H3mVY+lj4863TlgdB{A#}N8W1b5vL-Yo^M zmwX3+Z$8*n$Qt6aj9d}Lm3JiVBapm*5&Nh$g%)LRUxxAI{4L zc%uML7DsgZ$@P(Ye`|DoceehRM6#zmH>34_UC5lmrx91Gc`l)@3IlmYmgE|lHBhn7 zs{j_waxXR*hX1FX%7SicGmN+ECE)>?2Es>e#|x^PERlBCRrB4ijOiy@_8)Lb2| z$~7+Y*u4isp*t!&x7Ec9wp~vAPUN%nAkgnKySCvg=~zI&LoeSmcBi#BFu&%zPw>*) zWX_dVz9n{7qvXdOwGWyWm5=ovPbjj%@YuSL{(t@WLZ2n73NL#T%J<1v0)gI%i}~Ti zZ6HQ<=8mgR+2r)lv&HGnrigMZaD}GF6%X8k2HMS@*(Z#a)tjhyez%XsEUXm$cy+YQ zHGR)8kfCzdo_w}*5Bq|Hsq=u_uNUXM*su^0dzWL39WSZrPc@gk+<3x1slLIjB!2MF zJqO45mS4oOfz_Vwi!cWEwxnK|z)vtJ60u!l7MU=b7on-#8F=|eUYN;iajneUpm7iM z_qyX!{#007wgbCdTRC`5FLN(y5;3**(-wrEnpw2iqc7sciQS2sh;Em>;+AUBp+a{ zIk0Ov`z=QpVPM>)LwhEAzAP|wK!S2OgK~*qCx^Y0$NJSSYj{S;HK@jSi%jB6G_b0bE38>XJA5w`CYbP{D^6bS|k%MNU&RUQ_;f< zP)HmFGm^XJ9~@eZfr|C-tC4LfEv!6 zR$UiY{;gicLD;#g*VA_=p4cZxm1>I0Dh?r@7dFqU^{6yu)Y_NwWQq1W>0iRcbL(C3 zSL$;6#n{k;+~)1KZfCe{jV3$0g!z*OjW~fmXKI@+z%vT72(Vgx06hs`km=QR@V@@KXmf_*Ijays(zIQhkD`PN#7j zzl4Uh8Q-un)P$o{y{JurGs;KYR4y!(Rj; ze7?_k9{T=l_NZm{1|+C!Z|6G?k8&e#-{0MyRn}S$1`eLp_Ryn`yz1l&q8k~1&J%t0 zLfvPG3{$&SS!zOoJY++(s>`HAB&}Xd-#F#;mZ)_be@%s(!6#YDfY5e5SEYt=cULrwo3XRu*fqb;)r-4W3BV;!JA zRyxF;V`T|CS{s(5JQ(eVfk9E6N zO4J?(f)m)O8XDdl8h8)jOPHa*vKfN)_ff^Jp?qKxK>^7O1(nF=HdJIgC3 z)WVwQmA$5sW%%Ao2Q)r$@y{Z^z-FS&}7qqFJ+A-CUY?>2lAURiX zwr3_W19?8VFDmw{b<)FOWrL7s9gO?~+>Lm@6!Y{hK7mgS?cX_e(j&zQefOo`-bzbC zgNi<$5SVL|Iz^SQT=~F@NmP{e5`3*h^EfesPus}2SzC3#MtY(7Z4ftG@%U$HAWCG9 z9wu>5p9&-Q_}?4WvLn#GsoXEGf8XuVjA_$5Nd8xLYk75shSL}Hs!%k+-GZroL+)@= zQH{OTLc2ejhf;0s)QZ1kX&B3)iShLvwR+zUXo;*MPvy71W(c%H>YH1^9aUR3`oN(t z7b#A^E0Z9+>d)@v27fCWmedlG)pN>-lVKDl^&0@Id1tR1&Ah#%tk6;n#+r;m#s7qFJU2p|i%AC6nNR7(p;N#o|30Pejc149 z<;J9q*-~Kz0ePA!WP?@BPn57-u@+8&61v+f2MvNfaE4Pn19Z764b=-1OL-dxyIl{;FT{<|yuCWct|rm= zh@|%0BD+N!@r(I$lFsCICh(MU)xp*l-kQ!_=Nmf6HG8Ng0$;MU5ZAb;5YTEP!R+N( z=0EWFbkBneez2&|9%;R_tbZ+9s06XZT)6W7*5X+t1L!7L?q=}k91=QrQ_tCZ?cA=@ zow~0IKA*Eh>%IQA;+c&DwOVCJurKLQ2e|*Tt|~==k0{wss;3geAJ*6NR>m=#ncGhq z)ci7a*fY@(w;c3zd8TJC!a&sW3_5bA?`gn)`jnlW5GMcDJSCit ztw-J7c0;*xXrYxOHL+f-NSmwVgyHN0D=SaCYQ3HnYrr<@wQ+K%ry0=NY-q19gB1;S z%ZCpPu-4F=ZOY`skM}+DUee+%U4yZdWR^_H$$7-FkAO-It5@rw^-3hznkKWZ4l07L zlp$-Cl7M4XYt6794nvP-70`(e(40bLWrAc?JODGOI-B_nq%)HNWi2#tw$f+{j_d}Y z^|XEIBhE6RrwD1CqoWL*M$Vd#A34l3 zK#>?63P3Ugi>Hfru7x7l9Ek>+iMIk042lFI%b4^9&LY`R?H>%mpg~>)A>?|3KD<~5 zEtbU76NkT)e!z5!-#^5TP;9CcA`JX={z}OX3H#~}mwXADF%P-I5$}+oo6fP54!JHO zHqG-!T4ixP4U1+~EPM}KbZ#hk*^iAcS4j2{$k&lbvZZE+9*URDO474a{^mTH2%=or zi*$2JTxGu99$@EjB73?|cM%buxQ+cCCXmnr2|{lDq$`UijI3&-cj^#(l?m^ zGZe9_R5S6dvY_3V_v@<#!`0MlXC!yhH$>8WvUG#v(Aw9|;UPL{kY2j~2XP;WCs@~Q zl(S9V3}_a4*aGcXxiy~ww^+a;0?kow%ws?gM+>Djvy>Jnq=B;o3wo~F5-dz*0tIno zl=lia=s9yWu4p?gDaD|rg>b=HBnNd8tmu9DWK#TOIlzCd{+de9fP;) zSKvxyfgG0L@H$0nWz?!d$*N7jMprB~!R)OsJ&J09vG&onm0;1MuZ*3^Em7 zXh9F@B~iHWZll|8=Qx}4#TzYkf8}x3o3j%Apj(p=46uR~oZy)G^So-j0cdMJXMQXs z$%VK2ELqUt+?35?rZ9i`p=c6vDT>~L!Cww?vJ7A>#tO0kp)_k({`^HsV5uI8wv

    + ## Introduction This repository is a fork of the [mmdetection](https://github.com/open-mmlab/mmdetection) toolbox with the implementation of OA-Mix, @@ -17,6 +18,7 @@ The method enhances model robustness against domain shifts by generating diverse For more information on the details of OA-Mix and its use cases, please refer to the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076), presented at AAAI 2024. + ## Example of OA-Mix Below is an example showing the results of OA-Mix: @@ -25,27 +27,34 @@ Below is an example showing the results of OA-Mix: + ## Performance Improvement with OA-Mix Below is a performance comparison between a baseline object detection model and the same model with OA-Mix applied: -| Model | Dataset | mAP | Gauss. | Shot | Impulse | Defocus | Glass | Motion | Zoom | Snow | Frost | Fog | Bright | Contrast | Elastic | Pixel | JPEG | mPC | -| :-------------------: | :----------: | :--: | :----: | :--: | :-----: | :-----: | :---: | :----: | :--: | :--: | :---: | :--: | :----: | :------: | :-----: | ----- | :--: | :--: | -| Faster R-CNN | Cityscapes-C | 42.2 | 0.5 | 1.1 | 1.1 | 17.2 | 16.5 | 18.3 | 2.1 | 2.2 | 12.3 | 29.8 | 32.0 | 24.1 | 40.1 | 18.7 | 15.1 | 15.4 | -| Faster R-CNN + OA-Mix | Cityscapes-C | 42.7 | 7.2 | 9.6 | 7.7 | 22.8 | 18.8 | 21.9 | 5.4 | 5.2 | 23.6 | 37.3 | 38.7 | 31.9 | 40.2 | 22.2 | 20.2 | 20.8 | +| Model | Dataset | Backbone | mAP | Gauss. | Shot | Impulse | Defocus | Glass | Motion | Zoom | Snow | Frost | Fog | Bright | Contrast | Elastic | Pixel | JPEG | mPC | +| :-------------------: | :----------: | :------: | :--: | :----: | :--: | :-----: | :-----: | :---: | :----: | :--: | :--: | :---: | :--: | :----: | :------: | :-----: | ----- | :--: | :--: | +| Faster R-CNN | Cityscapes-C | R-50-FPN | 42.2 | 0.5 | 1.1 | 1.1 | 17.2 | 16.5 | 18.3 | 2.1 | 2.2 | 12.3 | 29.8 | 32.0 | 24.1 | 40.1 | 18.7 | 15.1 | 15.4 | +| Faster R-CNN + OA-Mix | Cityscapes-C | R-50-FPN | 42.7 | 7.2 | 9.6 | 7.7 | 22.8 | 18.8 | 21.9 | 5.4 | 5.2 | 23.6 | 37.3 | 38.7 | 31.9 | 40.2 | 22.2 | 20.2 | 20.8 | + +The model was evaluated using the [robust detection benchmark](https://github.com/bethgelab/robust-detection-benchmark), which can be run using the [test_robustness.py](tools/analysis_tools/test_robustness.py) script provided by mmdetection. + ## mmdetection Readme For information on mmdetection please refer to the [mmdetection readme](MMDETECTION_README.md). + ## Installation Please refer to [INSTALL.md](INSTALL.md) for installation and dataset preparation. + ## Get Started Please see [GETTING_STARTED.md](GETTING_STARTED.md) for the basic usage of MMDetection. + ## Citation If you use this toolbox or benchmark in your research, please cite this project. From face6eeb017bc5f2bf5619c8e10b5d0d28a22f92 Mon Sep 17 00:00:00 2001 From: Dasol Hong Date: Wed, 21 Aug 2024 10:17:32 +0900 Subject: [PATCH 5/7] Update README.md --- README.md | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/README.md b/README.md index 2e7fc021a47..aa532eb9d0a 100644 --- a/README.md +++ b/README.md @@ -40,21 +40,6 @@ Below is a performance comparison between a baseline object detection model and The model was evaluated using the [robust detection benchmark](https://github.com/bethgelab/robust-detection-benchmark), which can be run using the [test_robustness.py](tools/analysis_tools/test_robustness.py) script provided by mmdetection. -## mmdetection Readme - -For information on mmdetection please refer to the [mmdetection readme](MMDETECTION_README.md). - - -## Installation - -Please refer to [INSTALL.md](INSTALL.md) for installation and dataset preparation. - - -## Get Started - -Please see [GETTING_STARTED.md](GETTING_STARTED.md) for the basic usage of MMDetection. - - ## Citation If you use this toolbox or benchmark in your research, please cite this project. From 4ea8171fb450a3ce8d73d19c74720c01efd4241f Mon Sep 17 00:00:00 2001 From: Dasol Hong Date: Wed, 21 Aug 2024 10:29:46 +0900 Subject: [PATCH 6/7] Update README.md --- README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/README.md b/README.md index aa532eb9d0a..05298c170df 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,38 @@ Below is a performance comparison between a baseline object detection model and The model was evaluated using the [robust detection benchmark](https://github.com/bethgelab/robust-detection-benchmark), which can be run using the [test_robustness.py](tools/analysis_tools/test_robustness.py) script provided by mmdetection. +## How to use? + +Modify the `train_pipeline` in the configuration file as follows: +```python +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=[(2048, 800), (2048, 1024)], + keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='OAMix', version='oamix') + dict(type='PackDetInputs') +] +``` + +Alternatively, you can directly call the OAMix class as shown below: +```python +import numpy as np +from mmdet.datasets.transforms.oa_mix import OAMix + +# Generate random data +img = np.random.randint(0, 256, (427, 640, 3)).astype(np.uint8) +gt_bboxes = np.random.randn(3, 4).astype(np.float32) + +# Apply OA-Mix +oamix = OAMix() +img_aug = oamix({'img': img, 'gt_bboxes': gt_bboxes}) +``` + + ## Citation If you use this toolbox or benchmark in your research, please cite this project. From 50ed63a9c4db89a22f1c47dd46d29fe422343a7b Mon Sep 17 00:00:00 2001 From: root Date: Wed, 21 Aug 2024 01:59:12 +0000 Subject: [PATCH 7/7] fix(pre-commit): resolve conflicts from pre-commit hook adjustments --- README.md | 7 +- ...faster-rcnn_r50_fpn_1x_cityscapes_oamix.py | 21 ++--- mmdet/datasets/transforms/__init__.py | 92 +++++++++++++++---- mmdet/datasets/transforms/geometric.py | 80 ++++++++-------- 4 files changed, 125 insertions(+), 75 deletions(-) mode change 100644 => 100755 configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py mode change 100644 => 100755 mmdet/datasets/transforms/geometric.py diff --git a/README.md b/README.md index 05298c170df..6f750478492 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ - ## Introduction This repository is a fork of the [mmdetection](https://github.com/open-mmlab/mmdetection) toolbox with the implementation of OA-Mix, @@ -18,7 +17,6 @@ The method enhances model robustness against domain shifts by generating diverse For more information on the details of OA-Mix and its use cases, please refer to the paper [Object-Aware Domain Generalization for Object Detection](https://ojs.aaai.org/index.php/AAAI/article/view/28076), presented at AAAI 2024. - ## Example of OA-Mix Below is an example showing the results of OA-Mix: @@ -27,7 +25,6 @@ Below is an example showing the results of OA-Mix: - ## Performance Improvement with OA-Mix Below is a performance comparison between a baseline object detection model and the same model with OA-Mix applied: @@ -39,10 +36,10 @@ Below is a performance comparison between a baseline object detection model and The model was evaluated using the [robust detection benchmark](https://github.com/bethgelab/robust-detection-benchmark), which can be run using the [test_robustness.py](tools/analysis_tools/test_robustness.py) script provided by mmdetection. - ## How to use? Modify the `train_pipeline` in the configuration file as follows: + ```python train_pipeline = [ dict(type='LoadImageFromFile', backend_args=None), @@ -58,6 +55,7 @@ train_pipeline = [ ``` Alternatively, you can directly call the OAMix class as shown below: + ```python import numpy as np from mmdet.datasets.transforms.oa_mix import OAMix @@ -71,7 +69,6 @@ oamix = OAMix() img_aug = oamix({'img': img, 'gt_bboxes': gt_bboxes}) ``` - ## Citation If you use this toolbox or benchmark in your research, please cite this project. diff --git a/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py old mode 100644 new mode 100755 index 6517e3b83c1..973f072c84f --- a/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py +++ b/configs/oamix/faster-rcnn_r50_fpn_1x_cityscapes_oamix.py @@ -4,14 +4,6 @@ '../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py' ] - -# OA-Mix -oamix_config=dict( - type='OAMix', version='oamix', - box_scale=(0.05, 0.3), box_ratio=(3, 0.33), - sigma_ratio=0.2, score_thresh=10, -) - backend_args = None train_pipeline = [ dict(type='LoadImageFromFile', backend_args=backend_args), @@ -21,14 +13,17 @@ scale=[(2048, 800), (2048, 1024)], keep_ratio=True), dict(type='RandomFlip', prob=0.5), - oamix_config, + dict( + type='OAMix', + version='oamix', + box_scale=(0.05, 0.3), + box_ratio=(3, 0.33), + sigma_ratio=0.2, + score_thresh=10), dict(type='PackDetInputs') ] train_dataloader = dict( - num_workers=8, - dataset=dict(dataset=dict(pipeline=train_pipeline)) -) - + num_workers=8, dataset=dict(dataset=dict(pipeline=train_pipeline))) # Model model = dict( diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index 8e4ba6d2880..022acc165ec 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -1,19 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. from .augment_wrappers import AutoAugment, RandAugment from .colorspace import (AutoContrast, Brightness, Color, ColorTransform, - Contrast, Equalize, Invert, Posterize, Sharpness, - Solarize, SolarizeAdd, Invert4Mix) + Contrast, Equalize, Invert, Invert4Mix, Posterize, + Sharpness, Solarize, SolarizeAdd) from .formatting import (ImageToTensor, PackDetInputs, PackReIDInputs, PackTrackInputs, ToTensor, Transpose) from .frame_sampling import BaseFrameSample, UniformRefFrameSample from .geometric import (GeomTransform, Rotate, ShearX, ShearY, TranslateX, TranslateY) from .instaboost import InstaBoost -from .oa_mix import OAMix from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations, LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, LoadProposals, LoadTrackAnnotations) +from .oa_mix import OAMix from .text_transformers import LoadTextAnnotations, RandomSamplingNegPos from .transformers_glip import GTBoxSubOne_GLIP, RandomFlip_GLIP from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, @@ -26,21 +26,73 @@ from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder __all__ = [ - 'PackDetInputs', 'ToTensor', 'ImageToTensor', 'Transpose', - 'LoadImageFromNDArray', 'LoadAnnotations', 'LoadPanopticAnnotations', - 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'Resize', 'RandomFlip', - 'RandomCrop', 'SegRescale', 'MinIoURandomCrop', 'Expand', - 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad', - 'AutoAugment', 'CutOut', 'ShearX', 'ShearY', 'Rotate', 'Color', 'Equalize', - 'Brightness', 'Contrast', 'TranslateX', 'TranslateY', 'RandomShift', - 'Mosaic', 'MixUp', 'RandomAffine', 'YOLOXHSVRandomAug', 'CopyPaste', - 'FilterAnnotations', 'Pad', 'GeomTransform', 'ColorTransform', - 'RandAugment', 'Sharpness', 'Solarize', 'SolarizeAdd', 'Posterize', - 'AutoContrast', 'Invert', 'Invert4Mix', 'MultiBranch', 'RandomErasing', - 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', - 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', - 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', - 'ResizeShortestEdge', 'GTBoxSubOne_GLIP', 'RandomFlip_GLIP', - 'RandomSamplingNegPos', 'LoadTextAnnotations' 'OAMix', + 'PackDetInputs', + 'ToTensor', + 'ImageToTensor', + 'Transpose', + 'LoadImageFromNDArray', + 'LoadAnnotations', + 'LoadPanopticAnnotations', + 'LoadMultiChannelImageFromFiles', + 'LoadProposals', + 'Resize', + 'RandomFlip', + 'RandomCrop', + 'SegRescale', + 'MinIoURandomCrop', + 'Expand', + 'PhotoMetricDistortion', + 'Albu', + 'InstaBoost', + 'RandomCenterCropPad', + 'AutoAugment', + 'CutOut', + 'ShearX', + 'ShearY', + 'Rotate', + 'Color', + 'Equalize', + 'Brightness', + 'Contrast', + 'TranslateX', + 'TranslateY', + 'RandomShift', + 'Mosaic', + 'MixUp', + 'RandomAffine', + 'YOLOXHSVRandomAug', + 'CopyPaste', + 'FilterAnnotations', + 'Pad', + 'GeomTransform', + 'ColorTransform', + 'RandAugment', + 'Sharpness', + 'Solarize', + 'SolarizeAdd', + 'Posterize', + 'AutoContrast', + 'Invert', + 'Invert4Mix', + 'MultiBranch', + 'RandomErasing', + 'LoadEmptyAnnotations', + 'RandomOrder', + 'CachedMosaic', + 'CachedMixUp', + 'FixShapeResize', + 'ProposalBroadcaster', + 'InferencerLoader', + 'LoadTrackAnnotations', + 'BaseFrameSample', + 'UniformRefFrameSample', + 'PackTrackInputs', + 'PackReIDInputs', + 'FixScaleResize', + 'ResizeShortestEdge', + 'GTBoxSubOne_GLIP', + 'RandomFlip_GLIP', + 'RandomSamplingNegPos', + 'LoadTextAnnotations', + 'OAMix', ] diff --git a/mmdet/datasets/transforms/geometric.py b/mmdet/datasets/transforms/geometric.py old mode 100644 new mode 100755 index 9e76bcc15a3..eaa7eadcd67 --- a/mmdet/datasets/transforms/geometric.py +++ b/mmdet/datasets/transforms/geometric.py @@ -756,19 +756,21 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BBoxShearX(ShearX): + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape img = np.zeros_like(results['img'], dtype=np.float32) - for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): cy = (bbox[1] + bbox[3]) / 2 - shear_matrix = np.array([[1, mag, -mag*cy], [0, 1, 0]], dtype=np.float32) + shear_matrix = np.array([[1, mag, -mag * cy], [0, 1, 0]], + dtype=np.float32) shear_img = cv2.warpAffine( img_orig, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) img = (1.0 - mask) * img + mask * shear_img @@ -786,19 +788,20 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BBoxShearY(ShearY): + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape img = np.zeros_like(results['img'], dtype=np.float32) - for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): cx = (bbox[0] + bbox[2]) / 2 shear_matrix = np.float32([[1, 0, 0], [mag, 1, -mag * cx]]) shear_img = cv2.warpAffine( img_orig, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) img = (1.0 - mask) * img + mask * shear_img @@ -816,20 +819,22 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BBoxRotate(Rotate): + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape img = np.zeros_like(results['img'], dtype=np.float32) - for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): cx = (bbox[0] + bbox[2]) / 2 cy = (bbox[1] + bbox[3]) / 2 - translate_matrix = np.float32([[1, 0, w_img//2 - cx], [0, 1, h_img//2 - cy]]) + translate_matrix = np.float32([[1, 0, w_img // 2 - cx], + [0, 1, h_img // 2 - cy]]) translated_img = cv2.warpAffine( img_orig, - translate_matrix, - (w_img, h_img), + translate_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) """Rotate the image.""" @@ -838,11 +843,11 @@ def _transform_img(self, results: dict, mag: float) -> None: mag, border_value=self.img_border_value, interpolation=self.interpolation) - translate_matrix = np.float32([[1, 0, -w_img//2 + cx], [0, 1, -h_img//2 + cy]]) + translate_matrix = np.float32([[1, 0, -w_img // 2 + cx], + [0, 1, -h_img // 2 + cy]]) rotated_img = cv2.warpAffine( rotated_img, - translate_matrix, - (w_img, h_img), + translate_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) @@ -861,6 +866,7 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BBoxTranslateX(TranslateX): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) @@ -868,9 +874,9 @@ def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() img = np.zeros_like(results['img'], dtype=np.float32) - for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): w = bbox[2] - bbox[0] - """Translate the image horizontally.""" _mag = int(w * mag) translated_img = mmcv.imtranslate( @@ -895,6 +901,7 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BBoxTranslateY(TranslateY): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) @@ -902,9 +909,9 @@ def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() img = np.zeros_like(results['img'], dtype=np.float32) - for idx, (bbox, mask) in enumerate(zip(results['bboxes'], results['masks'])): + for idx, (bbox, + mask) in enumerate(zip(results['bboxes'], results['masks'])): h = bbox[3] - bbox[1] - """Translate the image horizontally.""" _mag = int(h * mag) translated_img = mmcv.imtranslate( @@ -929,26 +936,26 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BgShearX(ShearX): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape - """Shear the image horizontally.""" mask_max = np.max(results['masks'], axis=0) - shear_matrix = np.array([[1, mag, -mag*h_img//2], [0, 1, 0]], dtype=np.float32) + shear_matrix = np.array([[1, mag, -mag * h_img // 2], [0, 1, 0]], + dtype=np.float32) sheared_mask = cv2.warpAffine( mask_max, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) sheared_img = cv2.warpAffine( img_orig, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) @@ -966,25 +973,25 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BgShearY(ShearY): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape - """Shear the image vertically.""" mask_max = np.max(results['masks'], axis=0) - shear_matrix = np.array([[1, 0, 0], [mag, 1, -mag*w_img//2]], dtype=np.float32) + shear_matrix = np.array([[1, 0, 0], [mag, 1, -mag * w_img // 2]], + dtype=np.float32) sheared_mask = cv2.warpAffine( mask_max, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) sheared_img = cv2.warpAffine( img_orig, - shear_matrix, - (w_img, h_img), + shear_matrix, (w_img, h_img), borderValue=tuple([0] * 3), flags=cv2.INTER_LINEAR) @@ -1002,18 +1009,16 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BgRotate(Rotate): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() - """Rotate the image.""" mask_max = np.max(results['masks'], axis=0) rotated_mask = mmcv.imrotate( - mask_max, - mag, - border_value=0, - interpolation=self.interpolation) + mask_max, mag, border_value=0, interpolation=self.interpolation) rotated_img = mmcv.imrotate( img_orig, mag, @@ -1034,13 +1039,13 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BgTranslateX(TranslateX): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape - """Translate the image horizontally.""" _mag = w_img * mag mask_max = np.max(results['masks'], axis=0) @@ -1071,12 +1076,13 @@ def _transform_seg(self, results: dict, mag: float) -> None: @TRANSFORMS.register_module() class BgTranslateY(TranslateY): + def __init__(self, **kwargs): super().__init__(img_border_value=0, **kwargs) + def _transform_img(self, results: dict, mag: float) -> None: img_orig = results['img'].copy() (h_img, w_img, c_img) = img_orig.shape - """Translate the image horizontally.""" _mag = h_img * mag mask_max = np.max(results['masks'], axis=0)