diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/portforwarding.sh b/agents/portforwarding.sh index c356855..2f83df5 100644 --- a/agents/portforwarding.sh +++ b/agents/portforwarding.sh @@ -2,7 +2,7 @@ kubectl port-forward corenlp-7fd4974bb-8mq5g 4080:5001 -n chirpy kubectl port-forward dialogact-849b4b67d8-ngzd5 4081:5001 -n chirpy & kubectl port-forward g2p-7644ff75bd-cjj57 4082:5001 -n chirpy & kubectl port-forward gpt2ed-68f849f64b-wr8zw 4083:5001 -n chirpy & -kubectl port-forward questionclassifier-668c4fd6c6-fd586 4084:5001 -n chirpy & +kubectl port-forward questionclassifier-668c4fd6c6-7nl2k 4084:5001 -n chirpy & kubectl port-forward convpara-dbdc8dcfb-csktj 4085:5001 -n chirpy & kubectl port-forward entitylinker-59b9678b8-nmwx9 4086:5001 -n chirpy & kubectl port-forward blenderbot-695c7b5896-gkz2s 4087:5001 -n chirpy & diff --git a/chirpy/core/dialog_manager.py b/chirpy/core/dialog_manager.py index 9c0c473..52f6bb4 100644 --- a/chirpy/core/dialog_manager.py +++ b/chirpy/core/dialog_manager.py @@ -284,7 +284,16 @@ def update_rg_states(self, results: RankedResults, selected_rg: str): # Get the args needed for the update_state_if_not_chosen fn. That's (state, conditional_state) for all RGs except selected_rg other_rgs = [rg for rg in results.keys() if rg != selected_rg and not is_killed(results[rg])] logger.info(f"now, current states are {rg_states}") - args_list = [[rg_states[rg], results[rg].conditional_state] for rg in other_rgs] + + def rg_was_taken_over(rg): + if self.state_manager.last_state: + logger.debug(f"Rg that is selected is {selected_rg}. Currently evaluated rg is {rg}. " + f"rg == self.state_manager.last_state.active_rg is {rg == self.state_manager.last_state.active_rg}") + return rg_states[selected_rg].rg_that_was_taken_over and rg == self.state_manager.last_state.active_rg + else: + return None + + args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # Run update_state_if_not_chosen for other RGs logger.info(f'Starting to run update_state_if_not_chosen for {other_rgs}...') @@ -331,7 +340,6 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe # Get the states for the RGs we'll run, which we'll use as input to the get_response/get_prompt fn logger.debug('Copying RG states to use as input...') - input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts # import pdb; pdb.set_trace() @@ -343,10 +351,22 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe priority_modules = [last_state_active_rg] else: priority_modules = [] + + rg_was_taken_over = None + if self.state_manager.last_state_response: + rg_was_taken_over = self.state_manager.last_state_response.state.rg_that_was_taken_over + + def rg_to_resume(rg): + logger.debug(f"rg that was taken over is {rg_was_taken_over}. Currently evaluated rg is {rg}. 
" + f"rg == rg_was_taken_over is {rg == rg_was_taken_over}.") + return rg == rg_was_taken_over + + function_name = 'get_prompt_wrapper' if phase == 'prompt' else 'get_response' + args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, - function_name=f'get_{phase}', + function_name=function_name, timeout=timeout, - args_list=[[state] for state in input_rg_states], + args_list=args_list, # [[state] for state in input_rg_states], priority_modules=priority_modules) # Log the initial results diff --git a/chirpy/core/entity_linker/wiki_data_fetching.py b/chirpy/core/entity_linker/wiki_data_fetching.py index d60a454..b7c7f12 100644 --- a/chirpy/core/entity_linker/wiki_data_fetching.py +++ b/chirpy/core/entity_linker/wiki_data_fetching.py @@ -20,7 +20,7 @@ ANCHORTEXT_QUERY_TIMEOUT = 3.0 # seconds ENTITYNAME_QUERY_TIMEOUT = 1.0 # seconds -ARTICLES_INDEX_NAME = 'enwiki-20220107-articles' +ARTICLES_INDEX_NAME = 'enwiki-20200920-articles' # These are the fields we DO want to fetch from ES FIELDS_FILTER = ['doc_title', 'doc_id', 'categories', 'pageview', 'linkable_span_info', 'wikidata_categories_all', 'redirects', 'plural'] diff --git a/chirpy/core/entity_tracker/entity_tracker.py b/chirpy/core/entity_tracker/entity_tracker.py index 4cc8055..b6c87f1 100644 --- a/chirpy/core/entity_tracker/entity_tracker.py +++ b/chirpy/core/entity_tracker/entity_tracker.py @@ -23,6 +23,8 @@ class EntityTrackerState(object): def __init__(self): self.cur_entity = None # the current entity under discussion (can be None) + self.talked_unfinished = [] # entities that we have not finished talking about, but the rg is taken over + self.able_to_takeover_entities = [] # entities that are found in the response in that turn and can be used for wiki rg to takeover self.talked_rejected = [] # entities we talked about in the past, and stopped talking about because the user indicated they didn't want to talk about it any more self.talked_finished = [] # entities we talked about in the past, that aren't in talked_rejected self.talked_transitionable = [] @@ -97,7 +99,7 @@ def finish_entity(self, entity: Optional[WikiEntity], transition_is_possible=Tru logger.error(f"This is an error. 
This should be a WikiEntity object but {entity} is of type {type(entity)}")
             entity = None
 
-        if entity is not None and entity not in self.talked_finished:
+        if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished:
             logger.info(f'Putting entity {entity} on the talked_finished list')
             self.talked_finished.append(entity)
 
@@ -277,16 +279,23 @@ def condition_fn(entity_linker_result, linked_span, entity) -> bool:
         if nav_intent_output.neg_intent or nav_intent_output.pos_intent or last_answer_type in [AnswerType.QUESTION_SELFHANDLING, AnswerType.QUESTION_HANDOFF]:
             self.cur_entity = self.entity_initiated_on_turn
 
+        logger.info(f'Resetting able_to_takeover_entities to empty list')
+        self.able_to_takeover_entities = []
+
         for linked_span in current_state.entity_linker.high_prec:
             if not self.talked(linked_span.top_ent):
                 logger.info(f'Adding {linked_span.top_ent} to user_mentioned_untalked')
                 self.user_mentioned_untalked.append(linked_span.top_ent)
+                logger.info(f'Adding {linked_span.top_ent} to able_to_takeover_entities')
+                self.able_to_takeover_entities.append(linked_span.top_ent)
 
         logger.primary_info(f'The EntityTrackerState is now: {self}')
+        # logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}')
 
         # Update the entity tracker history
         self.history[-1]['user'] = self.cur_entity
 
+
     def record_untalked_high_prec_entities(self, current_state):
         """
         Take any entities in the entity linker's high precision set for this turn, and if they haven't been discussed,
@@ -313,6 +322,7 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
             result: ResponseGeneratorResult, PromptResult, or UpdateEntity
             rg: the name of the RG that provided the new entity
         """
+
         if isinstance(result, UpdateEntity):
             new_entity = result.cur_entity
             phase = 'get_entity'
@@ -325,6 +335,14 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
 
         transition_is_possible = not getattr(result, 'no_transition', False)
 
+        if self.able_to_takeover_entities and result.state.takeover_entity:
+            self.talked_unfinished.append(self.cur_entity)
+            new_entity = self.able_to_takeover_entities.pop()
+            logger.primary_info(f'Removing {new_entity} from {self.able_to_takeover_entities}')
+            self.able_to_takeover_entities = [e for e in self.able_to_takeover_entities if e != new_entity]
+            logger.info(f'After takeover, self.talked_unfinished is {self.talked_unfinished}, self.able_to_takeover_entities is {self.able_to_takeover_entities}'
+                        f' and self.talked_finished is {self.talked_finished}.')
+
         if new_entity == self.cur_entity:
             logger.primary_info(f'new_entity={new_entity} from {rg} RG {phase} is the same as cur_entity, so keeping EntityTrackerState the same')
         else:
@@ -340,11 +358,18 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up
             self.cur_entity = new_entity
             # Remove new_entity from user_mentioned_untalked
             if new_entity in self.user_mentioned_untalked:
-                logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked}')
+                logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked} after conversation is resumed.')
                 self.user_mentioned_untalked = [e for e in self.user_mentioned_untalked if e != new_entity]
             logger.primary_info(f'Set cur_entity to new_entity={new_entity} from {rg} RG {phase}')
-        logger.primary_info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}')
+
+        if new_entity in self.talked_unfinished:
+            archived_entity = new_entity
+            logger.info(
+                f"Removing 
archived_entity [{archived_entity}] from talked_unfinished [{self.talked_unfinished}]") + self.talked_unfinished.remove(archived_entity) + + logger.info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}') # If we're updating after receiving UpdateEntity from an RG, put any undiscussed high precision entities that # the user mentioned this turn in user_mentioned_untalked @@ -360,6 +385,8 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up def __repr__(self, show_history=False): output = f" bool: if ent is None: return True return ent in entities + + self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] self.talked_finished = [ent for ent in self.talked_finished if keep_entity(ent)] self.talked_rejected = [ent for ent in self.talked_rejected if keep_entity(ent)] self.user_mentioned_untalked = [ent for ent in self.user_mentioned_untalked if keep_entity(ent)] @@ -393,6 +422,8 @@ def reduce_size(self, max_size: int): # Make a set (no duplicates) of all the WikiEntities stored in this EntityTrackerState entity_set = set() entity_set.add(self.cur_entity) + entity_set.update(self.talked_unfinished) + entity_set.update(self.able_to_takeover_entities) entity_set.update(self.talked_finished) entity_set.update(self.talked_rejected) entity_set.update(self.user_mentioned_untalked) @@ -408,6 +439,8 @@ def replace_ent(ent: Optional[WikiEntity]): return None return entname2ent[ent.name] self.cur_entity = replace_ent(self.cur_entity) + self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] + self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] self.talked_finished = [replace_ent(ent) for ent in self.talked_finished] self.talked_rejected = [replace_ent(ent) for ent in self.talked_rejected] self.user_mentioned_untalked = [replace_ent(ent) for ent in self.user_mentioned_untalked] diff --git a/chirpy/core/logging_formatting.py b/chirpy/core/logging_formatting.py index 087bc32..b89df7d 100644 --- a/chirpy/core/logging_formatting.py +++ b/chirpy/core/logging_formatting.py @@ -2,23 +2,52 @@ from typing import Optional from colorama import Fore, Back +from pathlib import Path + +def get_active_branch_name(): + git_dir = Path(".") / ".git" + if git_dir.is_dir(): + head_dir = git_dir / "HEAD" + with head_dir.open("r") as f: content = f.read().splitlines() + + for line in content: + if line[0:4] == "ref:": + return line.partition("refs/heads/")[2] + else: # for integ. 
testing, we don't copy .git/ to the instance
+        return []
+
 LINEBREAK = ''
 
-# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-sensitive)
-# The path_strings are strings we'll search for (case-sensitive) in the path of the file that does the log message
+# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-insensitive)
+# The path_strings are strings we'll search for (case-insensitive) in the path of the file that does the log message
 # You can comment out parts of this dict and add your own components to make it easier to only see what you're working on
+# See https://rich.readthedocs.io/en/stable/appendix/colors.html for a list of rich colors
+# See https://github.com/willmcgugan/rich/blob/master/rich/_emoji_codes.py for emoji codes
 COLOR_SETTINGS = {
-    'WIKI': {'color': Fore.MAGENTA, 'path_strings': ['wiki']},
-    'MOVIES': {'color': Fore.GREEN, 'path_strings': ['movies']},
-    # 'NEWS': {'color': Fore.CYAN, 'path_strings': ['news']},
-    'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'path_strings': ['acknowledgment']},
-    # 'LAUNCH': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['launch']},
-    'CATEGORIES': {'color': Fore.YELLOW, 'path_strings': ['categories']},
-    'NEURAL_CHAT': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['neural_chat']},
-    'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'path_strings': ['entity_linker']},
-    'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'path_strings': ['entity_tracker']},
-    'experiments': {'color': Fore.LIGHTGREEN_EX, 'path_strings': ['experiments']},
-    'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['navigational_intent']}
+    'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'rich_color': '#0AAB42',
+                       'emoji': ':white_heavy_check_mark:', 'path_strings': ['acknowledgment']},
+    'ALEXA_COMMANDS': {'emoji': ':speaking_head_in_silhouette:', 'path_strings': ['alexa_commands']},
+    'ALIENS': {'rich_color': '#1EA8B3', 'emoji': ':alien:'},
+    'CATEGORIES': {'rich_color': '#15EBCE', 'emoji': ':newspaper:', 'path_strings': ['categories']},
+    'CORONAVIRUS': {'rich_color': '#F70C6E', 'emoji': ':face_with_medical_mask:'},
+    'FOOD': {'rich_color': '#97F20F', 'emoji': ':sushi:'},
+    'LAUNCH': {'emoji': ':checkered_flag:', 'path_strings': ['launch']},
+    'MOVIES': {'rich_color': '#F0D718', 'emoji': ':movie_camera:', 'path_strings': ['movies']},
+    'MUSIC': {'rich_color': '#0586FF', 'emoji': ':musical_notes:'},
+    'NEURAL_CHAT': {'rich_color': '#0EE827', 'emoji': ':brain:', 'path_strings': ['neural_chat']},
+    'NEWS': {'rich_color': '#1C64FF', 'emoji': ':newspaper:'},
+    'OFFENSIVE_USER': {'rich_color': '#EB5215', 'emoji': ':prohibited:'},
+    'ONE_TURN_HACK': {'rich_color': '#88B0B3', 'emoji': ':hammer:'},
+    'OPINION': {'rich_color': '#D011ED', 'emoji': ':thinking_face:'},
+    'PERSONAL_ISSUES': {'rich_color': '#BC3BEB', 'emoji': ':slightly_frowning_face:'},
+    'SPORTS': {'rich_color': '#EB8715', 'emoji': ':football:'},
+    'WIKI': {'rich_color': '#42C2F5', 'emoji': ':books:'},
+    'TRANSITION': {'rich_color': '#5FD700', 'emoji': ':soon_arrow:'},
+    'REOPEN': {'rich_color': '#5F00FF', 'emoji': ':door:'},
+    'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'rich_color': '#0BC3E3', 'path_strings': ['entity_linker']},
+    'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'rich_color': '#DB960D', 'path_strings': ['entity_tracker']},
+    'experiments': {'color': Fore.LIGHTGREEN_EX, 'rich_color': '#CADB0D', 'path_strings': ['experiments']},
+    'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'rich_color': 
'#DB0D93', 'path_strings': ['navigational_intent']} } LOG_FORMAT = '[%(levelname)s] [%(asctime)s] [fn_vers: {function_version}] [session_id: {session_id}] [%(pathname)s:%(lineno)d]\n%(message)s\n' @@ -33,8 +62,14 @@ def colored(str, fore=None, back=None, include_reset=True): new_str = '{}{}{}'.format(back, new_str, Back.RESET if include_reset else '') return new_str +def get_rich_color_for_rg(rg_name): + for component_name, settings in COLOR_SETTINGS.items(): + if component_name.lower() == rg_name.lower() and settings.get('rich_color'): + color = settings['rich_color'] + return f"[{color}]{rg_name}[/{color}]" + return rg_name -def get_line_color(line): +def get_line_color(line, branch_name): """ Given a line of logging (which is one line of a multiline log message), searches for component names at the beginning of the line. If one is found, returns its color. @@ -42,7 +77,9 @@ def get_line_color(line): first_part_line = line.strip().split()[0] for component_name, settings in COLOR_SETTINGS.items(): if component_name in first_part_line: - return settings['color'] + return settings.get('color') + if any(b.lower() in first_part_line.lower() for b in branch_name): + return Fore.BLUE return None @@ -62,7 +99,7 @@ def get_line_key(idx: int): class ChirpyFormatter(logging.Formatter): """ - A custom formatter that formats linebreaks and color according to logger_settings, and the context of each message. + A color formatter that formats linebreaks and color according to logger_settings, and the context of each message. Based on this: https://stackoverflow.com/a/14859558 """ @@ -72,6 +109,10 @@ def __init__(self, allow_multiline: bool, use_color: bool, session_id: Optional[ self.use_color = use_color self.session_id = session_id self.function_version = function_version + if self.use_color: + branch_name = get_active_branch_name() + branch_name = ''.join([x if x.isalpha() else ' ' for x in branch_name]) + self.branch_name = branch_name.split() self.update_format() def update_format(self): @@ -137,15 +178,19 @@ def format_color(self, record): lines = record.msg.split('\n') for idx, line in enumerate(lines): setattr(record, get_line_key(idx), line) # e.g. 
record['line_5'] -> the text of the 5th line of logging - line_colors = [get_line_color(line) for line in lines] # get the color for each line + line_colors = [get_line_color(line, self.branch_name) for line in lines] # get the color for each line self._style._fmt = self.fmt.replace('%(message)s', linecolored_msg_fmt(line_colors)) # this format string has keys for line_1, line_2, etc, along with line-specific colors # If the filepath of the calling function contains a path string for a colored component, return its color else: for component, settings in COLOR_SETTINGS.items(): - for path_string in settings['path_strings']: - if path_string in record.pathname: - self._style._fmt = colored(self.fmt, fore=settings['color']) + if settings.get('path_strings'): + for path_string in settings['path_strings']: + if path_string in record.pathname: + self._style._fmt = colored(self.fmt, fore=settings['color']) + continue + if any(b in record.pathname for b in self.branch_name): + self._style._fmt = colored(self.fmt, fore=Fore.BLUE) # Use the formatter class to do the formatting (with a possibly modified format) result = logging.Formatter.format(self, record) diff --git a/chirpy/core/logging_rich.py b/chirpy/core/logging_rich.py new file mode 100644 index 0000000..4557a64 --- /dev/null +++ b/chirpy/core/logging_rich.py @@ -0,0 +1,331 @@ +import logging +from datetime import datetime +from logging import Handler, LogRecord +from pathlib import Path +from collections import Iterable +from typing import ClassVar, Iterable, List, Optional, Type, TYPE_CHECKING, Union, Callable +import os +import rich + +from rich import get_console +from rich._log_render import LogRender, FormatTimeCallable +from rich.containers import Renderables +from rich.console import Console, ConsoleRenderable, RenderableType +from rich.highlighter import Highlighter, ReprHighlighter +from rich.text import Text, TextType +from rich.traceback import Traceback +from rich.logging import RichHandler + +from rich.table import Table + +from chirpy.core.logging_formatting import COLOR_SETTINGS + +PATH_WIDTH = 25 + +LEVEL_STYLES = {"primary_info": "dim", + "error": "bold red on bright_yellow"} + +LEVEL_LINE_COLORS = {"error": "red"} + +COBOT_HOME = os.environ.get('COBOT_HOME', os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +def get_rich_color(text): + for component_name, settings in COLOR_SETTINGS.items(): + if component_name.lower() in text.lower(): + return settings.get('rich_color') + if 'path_strings' in settings: + for path_string in settings['path_strings']: + if path_string.lower() in text.lower(): + return settings.get('rich_color') + return None + +def add_emoji(text, check_text = None): + if not check_text: + check_text = text + for component_name, settings in COLOR_SETTINGS.items(): + if component_name in check_text and settings.get('emoji'): + return settings['emoji'] + ' ' + text + if 'path_strings' in settings: + for path_string in settings['path_strings']: + if path_string in check_text and settings.get('emoji'): + return settings['emoji'] + ' ' + text + return text + +class ChirpyLogRender(LogRender): + def __call__( + self, + console: "Console", + renderables: Iterable["ConsoleRenderable"], + log_time: datetime = None, + time_format: Union[str, FormatTimeCallable] = None, + level: TextType = "", + path: str = None, + line_no: int = None, + link_path: str = None, + path_color: str = None, + ) -> "Table": + output = Table.grid(padding=(0, 1)) + output.expand = True + if self.show_level: + 
output.add_column(width=self.level_width) + if self.show_path and path: + output.add_column(width=PATH_WIDTH) #style="dim", + output.add_column(ratio=1, style="log.message", overflow="fold") + if self.show_time: + output.add_column(style="log.time") + row: List["RenderableType"] = [] + if self.show_level: + row.append(level[:3]) + if self.show_path and path: + path_text = Text.from_markup(path) + if line_no: + if len(path_text) > PATH_WIDTH - len(str(line_no)) - 2: + path_text.truncate(PATH_WIDTH - len(str(line_no)) - 2) + path_text.append("…") + path_text.append(f":{line_no}") + path_text.stylize(path_color) + row.append(path_text) + + row.append(Renderables(renderables)) + if self.show_time: + log_time = log_time or console.get_datetime() + time_format = time_format or self.time_format + if callable(time_format): + log_time_display = time_format(log_time) + else: + log_time_display = Text(log_time.strftime(time_format)[:-4] + ']') + if log_time_display == self._last_time and self.omit_repeated_times: + row.append(Text(" " * len(log_time_display))) + else: + row.append(log_time_display) + self._last_time = log_time_display + + output.add_row(*row) + return output + + +class ChirpyHandler(RichHandler): + DICT_OPEN_TAG: str = "[dict]\n" + DICT_CLOSE_TAG: str = "\n[/dict]" + + def __init__( + self, + level: Union[int, str] = logging.NOTSET, + console: Console = None, + *, + show_time: bool = True, + omit_repeated_times: bool = True, + show_level: bool = True, + show_path: bool = True, + enable_link_path: bool = True, + highlighter: Highlighter = None, + markup: bool = False, + rich_tracebacks: bool = False, + tracebacks_width: Optional[int] = None, + tracebacks_extra_lines: int = 3, + tracebacks_theme: Optional[str] = None, + tracebacks_word_wrap: bool = True, + tracebacks_show_locals: bool = False, + locals_max_length: int = 10, + locals_max_string: int = 80, + log_time_format: Union[str, FormatTimeCallable] = "[%x %X]", + filter_by_rg: str = None, + disable_annotation: bool = False, + ) -> None: + super().__init__( + level=level, + console=console, + show_time=show_time, + omit_repeated_times=omit_repeated_times, + show_level=show_level, + show_path=show_path, + enable_link_path=enable_link_path, + highlighter=highlighter, + markup=markup, + rich_tracebacks=rich_tracebacks, + tracebacks_width=tracebacks_width, + tracebacks_extra_lines=tracebacks_extra_lines, + tracebacks_theme=tracebacks_theme, + tracebacks_word_wrap=tracebacks_word_wrap, + tracebacks_show_locals=tracebacks_show_locals, + locals_max_length=locals_max_length, + locals_max_string=locals_max_string, + log_time_format=log_time_format, + ) + self._log_render = ChirpyLogRender( + show_time=show_time, + show_level=show_level, + show_path=show_path, + time_format=log_time_format, + omit_repeated_times=omit_repeated_times, + level_width=None, + ) + if filter_by_rg: + valid_rg_filenames = [f.name.lower() for f in os.scandir(os.path.join(COBOT_HOME, "chirpy/response_generators")) if f.is_dir()] + filter_by_rg = filter_by_rg.lower() + assert filter_by_rg in valid_rg_filenames, f"{filter_by_rg} does not specify a valid RG filename (must be a folder in chirpy/response_generators)" + self.filter_by_rg = filter_by_rg + self.disable_annotation = disable_annotation + + def process_dictionary(self, dict_text: str) -> "ConsoleRenderable": + lines = dict_text.split('\n') + pairs = [line.split('\u00a0' * 5) for line in lines] + assert all(len(p) == 2 for p in pairs) + grid = Table.grid(expand=True, padding=(0, 3)) + 
grid.add_column(justify="left", width=25) + grid.add_column(ratio=1) + for pair in pairs: + pair = [p.strip().strip("'") for p in pair] + name, value = pair + text_color = get_rich_color(name) + name = add_emoji(name) + if text_color: + grid.add_row(Text.from_markup(name, style=text_color), value) + else: + grid.add_row(name, value) + return grid + + def render_message(self, record: LogRecord, message: str) -> List["ConsoleRenderable"]: + """Render message text in to Text. + + record (LogRecord): logging Record. + message (str): String cotaining log message. + + Returns: + ConsoleRenderable: Renderable to display log message. + """ + use_markup = ( + getattr(record, "markup") if hasattr(record, "markup") else self.markup + ) + message_texts = [] + if record.levelname.lower() in LEVEL_LINE_COLORS: + color = LEVEL_LINE_COLORS[record.levelname.lower()] + message = "[" + color + "]" + message + message = message.replace('\n', "[/" + color + "]\n", 1) + if self.DICT_OPEN_TAG in message: + start = message.find(self.DICT_OPEN_TAG) + end = message.find(self.DICT_CLOSE_TAG) + dict_text = message[start + len(self.DICT_OPEN_TAG):end] + message_one = message[:start] + message_two = message[end + len(self.DICT_CLOSE_TAG):] + text_color = None + message_texts.append(Text.from_markup(message_one) if use_markup else Text(message_one)) + message_texts.append(self.process_dictionary(dict_text)) + message_texts.append(Text.from_markup(message_two) if use_markup else Text(message_two)) + else: + message_texts.append(Text.from_markup(message) if use_markup else Text(message)) + for message_text in message_texts: + if isinstance(message_text, Text): + if self.highlighter: + message_text = self.highlighter(message_text) + if self.KEYWORDS: + message_text.highlight_words(self.KEYWORDS, "logging.keyword") + return message_texts + + def get_level_text(self, record: LogRecord) -> Text: + """Get the level name from the record. + + Args: + record (LogRecord): LogRecord instance. + + Returns: + Text: A tuple of the style and level name. 
+ """ + level_name = record.levelname + level_text = Text.styled( + level_name[:3].ljust(8).capitalize(), f"logging.level.{level_name.lower()}" + ) + if level_name.lower() in LEVEL_STYLES: + level_text.stylize(LEVEL_STYLES[level_name.lower()]) + return level_text + + def emit(self, record: LogRecord) -> None: + """Invoked by logging.""" + message = self.format(record) + traceback = None + if ( + self.rich_tracebacks + and record.exc_info + and record.exc_info != (None, None, None) + ): + exc_type, exc_value, exc_traceback = record.exc_info + assert exc_type is not None + assert exc_value is not None + traceback = Traceback.from_exception( + exc_type, + exc_value, + exc_traceback, + width=self.tracebacks_width, + extra_lines=self.tracebacks_extra_lines, + theme=self.tracebacks_theme, + word_wrap=self.tracebacks_word_wrap, + show_locals=self.tracebacks_show_locals, + locals_max_length=self.locals_max_length, + locals_max_string=self.locals_max_string, + ) + message = record.getMessage() + if self.formatter: + record.message = record.getMessage() + formatter = self.formatter + if hasattr(formatter, "usesTime") and formatter.usesTime(): # type: ignore + record.asctime = formatter.formatTime(record, formatter.datefmt) + message = formatter.formatMessage(record) + + if self.should_show(record): + message_renderable = self.render_message(record, message) + log_renderable = self.render( + record=record, traceback=traceback, message_renderable=message_renderable + ) + self.console.print(log_renderable) + + def should_show(self, record): + if record.levelname.lower() in ['error', 'warning']: + return True + path = record.pathname + if 'response_generators' not in path.lower(): + return not self.disable_annotation + if self.filter_by_rg is None: + return True + return self.filter_by_rg in path.lower() + + def render( + self, + *, + record: LogRecord, + traceback: Optional[Traceback], + message_renderable: "ConsoleRenderable", + ) -> "ConsoleRenderable": + """Render log for display. + + Args: + record (LogRecord): logging Record. + traceback (Optional[Traceback]): Traceback instance or None for no Traceback. + message_renderable (ConsoleRenderable): Renderable (typically Text) containing log message contents. + + Returns: + ConsoleRenderable: Renderable to display log. + """ + path_color = get_rich_color(record.pathname) + path = Path(record.pathname).name + path = add_emoji(path, record.pathname) + if record.levelname.lower() in LEVEL_LINE_COLORS: + path_color = LEVEL_LINE_COLORS[record.levelname.lower()] + level = self.get_level_text(record) + time_format = None if self.formatter is None else self.formatter.datefmt + log_time = datetime.fromtimestamp(record.created) + + if traceback: + message_renderable.append(traceback) + + log_renderable = self._log_render( + self.console, + message_renderable, + log_time=log_time, + time_format=time_format, + level=level, + path=path, + line_no=record.lineno, + link_path=record.pathname if self.enable_link_path else None, + path_color=path_color, + ) + return log_renderable \ No newline at end of file diff --git a/chirpy/core/logging_utils.py b/chirpy/core/logging_utils.py index e460eb7..9d4f1e8 100644 --- a/chirpy/core/logging_utils.py +++ b/chirpy/core/logging_utils.py @@ -1,5 +1,6 @@ """ -This file contains functions to create and configure the chirpylogger +This file contains functions to create and configure the chirpylogger, which is a single simple logger to replace +the more complicated LoggerFactory that came with Cobot. 
""" import logging @@ -8,6 +9,8 @@ from dataclasses import dataclass from typing import Optional from chirpy.core.logging_formatting import ChirpyFormatter +from chirpy.core.logging_rich import ChirpyHandler +from rich.highlighter import NullHighlighter PRIMARY_INFO_NUM = logging.INFO + 5 # between INFO and WARNING @@ -22,18 +25,24 @@ class LoggerSettings: logtoscreen_allow_multiline: bool # If true, log-to-screen messages contain \n. If false, all the \n are replaced with integ_test: bool # If True, we setup the logger in a special way to work with nosetests remove_root_handlers: bool # If True, we remove all other handlers on the root logger + allow_rich_formatting: bool = True + filter_by_rg: str = None + disable_annotation: bool = False # AWS adds a LambdaLoggerHandler to the root handler, which causes duplicate logging because we have our customized # StreamHandler on the root logger too. So we set remove_root_handlers=True to remove the LambdaLoggerHandler. # See here: https://stackoverflow.com/questions/50909824/getting-logs-twice-in-aws-lambda-function -PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.INFO, +PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.INFO + 5, logtoscreen_usecolor=True, logtofile_level=None, logtofile_path=None, - logtoscreen_allow_multiline=False, + logtoscreen_allow_multiline=True, integ_test=False, - remove_root_handlers=True) + remove_root_handlers=True, + allow_rich_formatting=True, + filter_by_rg=None, + disable_annotation=False) def setup_logger(logger_settings, session_id=None): @@ -85,11 +94,23 @@ def setup_logger(logger_settings, session_id=None): chirpy_logger.setLevel(logging.DEBUG) # Create the stream handler and attach it to the root logger - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setLevel(logger_settings.logtoscreen_level) - stream_formatter = ChirpyFormatter(allow_multiline=logger_settings.logtoscreen_allow_multiline, use_color=logger_settings.logtoscreen_usecolor, session_id=session_id) - stream_handler.setFormatter(stream_formatter) - root_logger.addHandler(stream_handler) + print("allow_multiline = ", logger_settings.logtoscreen_allow_multiline ) + print("rich formatting = ", logger_settings.allow_rich_formatting) + if logger_settings.logtoscreen_allow_multiline and logger_settings.allow_rich_formatting: + root_logger.addHandler(ChirpyHandler(log_time_format="[%H:%M:%S.%f]", + level=logger_settings.logtoscreen_level, + markup=True, + highlighter=NullHighlighter(), + filter_by_rg=logger_settings.filter_by_rg, + disable_annotation=logger_settings.disable_annotation)) + else: + # Use the stream handler if no multi-line to not mess up production logs + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setLevel(logger_settings.logtoscreen_level) + stream_formatter = ChirpyFormatter(allow_multiline=logger_settings.logtoscreen_allow_multiline, use_color=logger_settings.logtoscreen_usecolor, session_id=session_id) + stream_handler.setFormatter(stream_formatter) + root_logger.addHandler(stream_handler) + #root_logger.addHandler(RichHandler(log_time_format="[%H:%M:%S]", level=logger_settings.logtoscreen_level, markup=True)) # Create the file handler and attach it to the root logger if logger_settings.logtofile_path: @@ -102,7 +123,7 @@ def setup_logger(logger_settings, session_id=None): # Mark that the root logger has the chirpy handlers attached root_logger.chirpy_handlers = True - # Add the custom PRIMARY_INFO level to chirpy logger + # Add the color PRIMARY_INFO level to 
chirpy logger add_new_level(chirpy_logger, 'PRIMARY_INFO', PRIMARY_INFO_NUM) return chirpy_logger @@ -158,4 +179,4 @@ def update_logger(session_id, function_version): for handler in root_logger.handlers: if isinstance(handler.formatter, ChirpyFormatter): handler.formatter.update_session_id(session_id) - handler.formatter.update_function_version(function_version) + handler.formatter.update_function_version(function_version) \ No newline at end of file diff --git a/chirpy/core/response_generator/response_generator.py b/chirpy/core/response_generator/response_generator.py index 6063aa9..8c8c47c 100644 --- a/chirpy/core/response_generator/response_generator.py +++ b/chirpy/core/response_generator/response_generator.py @@ -24,6 +24,7 @@ from chirpy.response_generators.music.utils import WikiEntityInterface from concurrent import futures +import copy logger = logging.getLogger('chirpylogger') @@ -41,7 +42,7 @@ def __init__(self, disallow_start_from=None, can_give_prompts=False, state_constructor=None, - conditional_state_constructor=None + conditional_state_constructor=None, ): """Creates a new Response Generator. @@ -93,7 +94,8 @@ def update_state_if_chosen(self, state, conditional_state): if response_types is not None: state.response_types = construct_response_types_tuple(response_types) - if conditional_state is None: return state + if conditional_state is None: + return state if conditional_state: for attr in dir(conditional_state): @@ -101,15 +103,20 @@ def update_state_if_chosen(self, state, conditional_state): val = getattr(conditional_state, attr) if val != NO_UPDATE: setattr(state, attr, val) state.num_turns_in_rg += 1 + return state - def update_state_if_not_chosen(self, state, conditional_state): + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): """ By default, this sets the prev_treelet_str and next_treelet_str to '' and resets num_turns_in_rg to 0. Response types are also saved. No other attributes are updated. 
All other attributes in ConditionalState are set to NO-UPDATE """ + if rg_was_taken_over: + state.archived_state = copy.deepcopy(state) + logging.info(f"Save current state as archived_state for conversation to be resumed: {state.archived_state}") + response_types = self.get_cache(f'{self.name}_response_types') if response_types is not None: state.response_types = construct_response_types_tuple(response_types) @@ -285,6 +292,9 @@ def get_current_entity(self, initiated_this_turn=False): else: return self.state_manager.current_state.entity_tracker.cur_entity + def get_most_recent_able_to_takeover_entity(self): + return self.state_manager.current_state.entity_tracker.able_to_takeover_entities[-1] + def get_entity_tracker(self): return self.state_manager.current_state.entity_tracker @@ -861,7 +871,7 @@ def get_last_rg_in_control(self) -> Optional[str]: return self.state_manager.last_state.selected_response_rg - def get_response(self, state) -> ResponseGeneratorResult: + def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResult: response_types = self.identify_response_types(self.utterance) logger.primary_info(f"{self.name} identified response_types: {response_types}") self.state = state @@ -915,19 +925,24 @@ def get_response(self, state) -> ResponseGeneratorResult: if not is_continuing_conversation: # allow the first branch to divert here logger.primary_info(f"{self.name} is not currently active, so checking if it should activate") - activation_check_fns = { (lambda: self.get_last_active_rg() in self.disallow_start_from): self.get_fallback_result, (lambda: True): self.handle_direct_navigational_intent, + (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, (lambda: True): self.handle_current_entity, (lambda: True): self.get_intro_treelet_response, (lambda: True): self.handle_custom_activation_checks, } + logging.debug(f"DEBUG HANDOVER {self.last_rg_willing_to_handover_control()}, {self.exist_able_to_takeover_entities()}") + for activation_condition, activation_check_fn in activation_check_fns.items(): if activation_condition(): response = activation_check_fn() - if response: return self.possibly_augment_with_prompt(response) + + if response: + return self.possibly_augment_with_prompt(response) + response = self.handle_default_post_checks() if response: @@ -935,6 +950,73 @@ def get_response(self, state) -> ResponseGeneratorResult: return self.get_fallback_result() + def last_rg_willing_to_handover_control(self): + last_active_rg_prompt = self.state_manager.last_state_response + if last_active_rg_prompt: + return last_active_rg_prompt.last_rg_willing_to_handover_control + else: + return False + + def exist_able_to_takeover_entities(self): + return len(self.state_manager.current_state.entity_tracker.able_to_takeover_entities) != 0 + + def get_takeover_response(self): + logging.info(f"{self.name} null get_takeover_response") + return None + + def takeover_rg_willing_to_handback_control(self): + last_active_rg_prompt = self.state_manager.last_state_response + if last_active_rg_prompt: + return last_active_rg_prompt.takeover_rg_willing_to_handback_control + else: + return False + + def get_resuming_statement(self, state) -> ResponseGeneratorResult: + logging.info(f"{self.name} null get_resuming_statement") + return self.emptyPrompt() + + def augment_resuming_statement(self, resuming_statement_first_treelet): + resuming_conversation_second_treelet_str = 
resuming_statement_first_treelet.resuming_conversation_next_treelet + logger.debug(f"The prompt treelet for resuming conversation is {resuming_conversation_second_treelet_str}") + resuming_conversation_second_treelet = self.treelets[resuming_conversation_second_treelet_str] + resuming_prompt_second_treelet = resuming_conversation_second_treelet.get_prompt() + logger.debug(f"The prompt for resuming conversation is {resuming_prompt_second_treelet}") + if resuming_prompt_second_treelet: + resuming_statement_first_treelet.text = f"{resuming_statement_first_treelet.text} {resuming_prompt_second_treelet.text}" + resuming_statement_first_treelet.conditional_state.next_treelet_str = resuming_conversation_second_treelet_str + for attr_to_copy in ['state', 'cur_entity', 'expected_type', 'answer_type', + 'last_rg_willing_to_handover_control', 'takeover_rg_willing_to_handback_control']: + attr_template = getattr(resuming_prompt_second_treelet, attr_to_copy) + setattr(resuming_statement_first_treelet, attr_to_copy, attr_template) + resuming_statement_first_treelet.resuming_conversation_next_treelet = None + return resuming_statement_first_treelet + + def resume_conversation(self): + logger.debug( + f"The archived_state for resuming conversation is {self.state_manager.current_state.response_generator_states[self.name].archived_state} in {self.name}") + archived_state = self.state_manager.current_state.response_generator_states[self.name].archived_state + self.state = archived_state + logger.error(f"The state of {self.name} after retrieving archived_state is: {self.state}") + + first_treelet_str = self.state.next_treelet_str + assert first_treelet_str in self.treelets + first_treelet = self.treelets[first_treelet_str] + resuming_statement_first_treelet = first_treelet.get_resuming_statement() + logger.info(f"The resuming statement generated from the current treelet is {resuming_statement_first_treelet}") + + resuming_conversation = self.augment_resuming_statement(resuming_statement_first_treelet) + logger.info(f"The resuming statement after augmented with a prompt is {resuming_conversation}") + + return resuming_conversation + + + def get_prompt_wrapper(self, state, rg_to_resume=False): + if self.takeover_rg_willing_to_handback_control(): + if rg_to_resume: + return self.resume_conversation() + else: + return self.get_prompt(state) + def possibly_augment_with_prompt(self, response): """ @@ -970,8 +1052,11 @@ def continue_conversation(self, response_types) -> Optional[ResponseGeneratorRes next_treelet_str = self.state.next_treelet_str next_treelet = None - response_priority = ResponsePriority.STRONG_CONTINUE - + if self.last_rg_willing_to_handover_control(): # we talked last turn and decided to handover... 
+ response_priority = ResponsePriority.WEAK_CONTINUE + else: + response_priority = ResponsePriority.STRONG_CONTINUE + logger.error(f"In continue_conversation, self.state is {self.state}, next_treelet_str is {next_treelet_str}, priority is {response_priority}") if next_treelet_str is None: return self.emptyResult() # continue from some other RG elif next_treelet_str == '': @@ -1013,7 +1098,6 @@ def continue_conversation(self, response_types) -> Optional[ResponseGeneratorRes logger.info(f"Continuing conversation from {next_treelet_str} for {self.name}") assert next_treelet_str in self.treelets next_treelet = self.treelets[next_treelet_str] - response_priority = ResponsePriority.STRONG_CONTINUE if next_treelet is not None: response = next_treelet.get_response(response_priority, ) diff --git a/chirpy/core/response_generator/state.py b/chirpy/core/response_generator/state.py index 01d3080..6c451cf 100644 --- a/chirpy/core/response_generator/state.py +++ b/chirpy/core/response_generator/state.py @@ -3,6 +3,8 @@ from chirpy.core.response_generator.response_type import ResponseType +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity + import logging logger = logging.getLogger('chirpylogger') @@ -22,12 +24,18 @@ class BaseState: next_treelet_str: Optional[str] = '' response_types: Tuple[str] = () num_turns_in_rg: int = 0 + archived_state: "BaseState" = None + rg_that_was_taken_over: str = None + takeover_entity: WikiEntity = None @dataclass class BaseConditionalState: prev_treelet_str: str = '' next_treelet_str: Optional[str] = '' response_types: Tuple[str] = NO_UPDATE + archived_state: "BaseState" = NO_UPDATE + rg_that_was_taken_over: str = NO_UPDATE + takeover_entity: WikiEntity = NO_UPDATE def construct_response_types_tuple(response_types): return tuple([str(x) for x in response_types]) diff --git a/chirpy/core/response_generator/treelet.py b/chirpy/core/response_generator/treelet.py index 805bf9b..13337da 100644 --- a/chirpy/core/response_generator/treelet.py +++ b/chirpy/core/response_generator/treelet.py @@ -50,6 +50,9 @@ def get_current_state(self): def get_current_entity(self, initiated_this_turn=False): return self.rg.get_current_entity(initiated_this_turn=initiated_this_turn) + def get_most_recent_able_to_takeover_entity(self): + return self.rg.get_most_recent_able_to_takeover_entity() + def get_sentiment(self): return self.rg.get_sentiment() diff --git a/chirpy/core/response_generator_datatypes.py b/chirpy/core/response_generator_datatypes.py index 3f12221..510daa0 100644 --- a/chirpy/core/response_generator_datatypes.py +++ b/chirpy/core/response_generator_datatypes.py @@ -33,7 +33,12 @@ def __init__(self, smooth_handoff: Optional[SmoothHandoff] = None, conditional_state=None, tiebreak_priority=None, - no_transition=False): + no_transition=False, + last_rg_willing_to_handover_control=False, + rg_that_was_taken_over =None, + takeover_entity=None, + takeover_rg_willing_to_handback_control=False + ): """ :param text: text of the response :param priority: priority of the response @@ -98,6 +103,10 @@ def __init__(self, self.conditional_state = conditional_state self.tiebreak_priority = tiebreak_priority self.no_transition = no_transition + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control def reduce_size(self, max_size:int = None): """Gracefully degrade by 
removing non essential attributes. @@ -124,7 +133,13 @@ def __init__(self, cur_entity: Optional[WikiEntity], expected_type: Optional[EntityGroup] = None, conditional_state=None, - answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING): + answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING, + last_rg_willing_to_handover_control=False, + rg_that_was_taken_over =None, + takeover_entity =None, + takeover_rg_willing_to_handback_control=False, + resuming_conversation_next_treelet=None + ): """ :param text: text of the response :param prompt_type: the type of response being given, typically CONTEXTUAL or GENERIC @@ -163,6 +178,11 @@ def __init__(self, self.state = state self.conditional_state = conditional_state self.answer_type = answer_type + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control + self.resuming_conversation_next_treelet = resuming_conversation_next_treelet def __repr__(self): return 'PromptResult' + str(self.__dict__) diff --git a/chirpy/core/state.py b/chirpy/core/state.py index b9cb559..310c35e 100644 --- a/chirpy/core/state.py +++ b/chirpy/core/state.py @@ -78,6 +78,7 @@ def update_from_last_state(self, last_state): self.entity_tracker.init_for_new_turn() self.experiments = last_state.experiments self.turn_num = last_state.turn_num + 1 + try: self.turns_since_last_active = last_state.turns_since_last_active except AttributeError: diff --git a/chirpy/core/state_manager.py b/chirpy/core/state_manager.py index f211297..c535874 100644 --- a/chirpy/core/state_manager.py +++ b/chirpy/core/state_manager.py @@ -30,3 +30,5 @@ def last_state_response(self): if not self.last_state: return None if hasattr(self.last_state, 'prompt_results'): return self.last_state.prompt_results[self.last_state.active_rg] else: return self.last_state.response_results[self.last_state.active_rg] + + diff --git a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py index bbca842..e88f417 100644 --- a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py +++ b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py @@ -101,7 +101,7 @@ def handle_custom_continuation_checks(self): # If neither matched, allow another RG to handle return self.emptyResult() - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.has_just_asked_to_exit = False return state diff --git a/chirpy/response_generators/food/food_response_generator.py b/chirpy/response_generators/food/food_response_generator.py index 540a31b..b30497c 100644 --- a/chirpy/response_generators/food/food_response_generator.py +++ b/chirpy/response_generators/food/food_response_generator.py @@ -20,6 +20,7 @@ from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier from chirpy.response_generators.food.food_helpers import * + logger = logging.getLogger('chirpylogger') class FoodResponseGenerator(ResponseGenerator): @@ -31,9 +32,11 @@ def 
__init__(self, state_manager) -> None: self.comment_on_favorite_type_treelet = CommentOnFavoriteTypeTreelet(self) self.ask_favorite_food_treelet = AskFavoriteFoodTreelet(self) self.factoid_treelet = FactoidTreelet(self) + treelets = { treelet.name: treelet for treelet in [self.introductory_treelet, self.open_ended_user_comment_treelet, - self.comment_on_favorite_type_treelet, self.factoid_treelet, self.ask_favorite_food_treelet] + self.comment_on_favorite_type_treelet, self.factoid_treelet, self.ask_favorite_food_treelet + ] } super().__init__(state_manager, treelets=treelets, intent_templates=[], can_give_prompts=True, state_constructor=State, @@ -87,4 +90,4 @@ def get_neural_response(self, prefix=None, allow_questions=False, conditions=Non def get_prompt(self, state): self.state = state self.response_types = self.get_cache(f'{self.name}_response_types') - return self.emptyPrompt() + return self.emptyPrompt() \ No newline at end of file diff --git a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py index fa3c570..c5b6a1a 100644 --- a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py +++ b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py @@ -29,5 +29,6 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( next_treelet_str="food_introductory_treelet", cur_food=None), - expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related + expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related, + last_rg_willing_to_handover_control=False ) diff --git a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py index 6d21c16..af3a02f 100644 --- a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py +++ b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py @@ -38,7 +38,8 @@ def get_prompt(self, conditional_state=None): return None return PromptResult(text, PromptType.CONTEXTUAL, state, conditional_state=conditional_state, - cur_entity=entity, answer_type=AnswerType.QUESTION_SELFHANDLING) + cur_entity=entity, answer_type=AnswerType.QUESTION_SELFHANDLING, + ) def get_best_candidate_user_entity(self, utterance, cur_food): def condition_fn(entity_linker_result, linked_span, entity): @@ -93,5 +94,28 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity=entity, conditional_state=ConditionalState( prompt_treelet=self.rg.open_ended_user_comment_treelet.name, - cur_food=cur_food_entity) + cur_food=cur_food_entity), + last_rg_willing_to_handover_control=False ) + + def get_resuming_statement(self, prompt_type=PromptType.FORCE_START, **kwargs): + logger.error(f"GET_STATEMENT_RESPONSE got triggered.") + state, utterance, response_types = self.get_state_utterance_response_types() + entity = self.rg.get_current_entity(initiated_this_turn=False) + cur_food_entity = state.cur_food + cur_food = cur_food_entity.name + cur_talkable_food = cur_food_entity.talkable_name + + if get_custom_question(cur_food) is not None: + custom_question_answer = get_custom_question_answer(cur_food) + text = f"Anyway, personally, when it comes to {cur_talkable_food}, I really like {custom_question_answer}." 
+ else: + other_type = sample_from_type(cur_food) + text = f"Anyway, personally, I really like {other_type}" + + return PromptResult(text=text, prompt_type=prompt_type, state=state, + conditional_state=ConditionalState( + prompt_treelet=self.rg.open_ended_user_comment_treelet.name, + cur_food=cur_food_entity), + cur_entity=entity, + resuming_conversation_next_treelet=self.rg.open_ended_user_comment_treelet.name) \ No newline at end of file diff --git a/chirpy/response_generators/food/treelets/factoid_treelet.py b/chirpy/response_generators/food/treelets/factoid_treelet.py index 217303f..3c39b60 100644 --- a/chirpy/response_generators/food/treelets/factoid_treelet.py +++ b/chirpy/response_generators/food/treelets/factoid_treelet.py @@ -27,7 +27,8 @@ def get_prompt(self, conditional_state=None): conditional_state = ConditionalState(cur_food=cur_food) entity = self.rg.state_manager.current_state.entity_tracker.cur_entity return PromptResult(text=get_factoid(cur_food), prompt_type=PromptType.CONTEXTUAL, - state=state, cur_entity=entity, conditional_state=conditional_state, answer_type=AnswerType.QUESTION_SELFHANDLING) + state=state, cur_entity=entity, conditional_state=conditional_state, answer_type=AnswerType.QUESTION_SELFHANDLING, + ) def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): """ Returns the response. """ @@ -50,5 +51,6 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity=None, conditional_state=ConditionalState( prev_treelet_str=self.name, - cur_food=cur_food - )) + cur_food=cur_food), + last_rg_willing_to_handover_control=True + ) diff --git a/chirpy/response_generators/food/treelets/introductory_treelet.py b/chirpy/response_generators/food/treelets/introductory_treelet.py index d4e0448..20c8155 100644 --- a/chirpy/response_generators/food/treelets/introductory_treelet.py +++ b/chirpy/response_generators/food/treelets/introductory_treelet.py @@ -73,7 +73,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): needs_prompt=False, state=state, cur_entity=entity, conditional_state=ConditionalState(cur_food=entity, - prompt_treelet=prompt_treelet)) + prompt_treelet=prompt_treelet), + last_rg_willing_to_handover_control=True) def get_prompt(self, **kwargs): return None @@ -90,3 +91,4 @@ def get_prompt(self, **kwargs): # cur_treelet_str="get_other_type", # cur_food=entity.name, # response=prompt_text)) + diff --git a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py index df132c4..ab660c5 100644 --- a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py +++ b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py @@ -33,7 +33,7 @@ def get_prompt(self, conditional_state=None): pronoun = infl('them', entity.is_plural) if best_attribute: text = 'What do you think?' else: text = f'What do you like best about {pronoun}?' - return PromptResult(text, PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) + return PromptResult(text=text, prompt_type=PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): """ Returns the response. 
""" @@ -73,5 +73,6 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( prev_treelet_str=self.name, prompt_treelet=prompt_treelet, - cur_food=None) + cur_food=None), + last_rg_willing_to_handover_control=False ) diff --git a/chirpy/response_generators/launch/launch_response_generator.py b/chirpy/response_generators/launch/launch_response_generator.py index 8c0fe90..c8a0bd4 100644 --- a/chirpy/response_generators/launch/launch_response_generator.py +++ b/chirpy/response_generators/launch/launch_response_generator.py @@ -50,7 +50,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state.asked_name_counter = 1 return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.next_treelet_str = None return state diff --git a/chirpy/response_generators/music/music_response_generator.py b/chirpy/response_generators/music/music_response_generator.py index c75fdbd..68bbe46 100644 --- a/chirpy/response_generators/music/music_response_generator.py +++ b/chirpy/response_generators/music/music_response_generator.py @@ -76,7 +76,7 @@ def update_state_if_chosen(self, state, conditional_state): state.discussed_entities.append(state.cur_singer_str) return state - def update_state_if_not_chosen(self, state, conditional_state): + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): state = super().update_state_if_not_chosen(state, conditional_state) return state diff --git a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py index f3f2056..8ac8b18 100644 --- a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py +++ b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py @@ -188,7 +188,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi state.update_if_chosen(conditional_state) return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: logger.primary_info(f"Neural chat state is {state}") if conditional_state is not None: state.update_if_not_chosen(conditional_state) diff --git a/chirpy/response_generators/neural_chat/state.py b/chirpy/response_generators/neural_chat/state.py index d83f9c6..e5045b8 100644 --- a/chirpy/response_generators/neural_chat/state.py +++ b/chirpy/response_generators/neural_chat/state.py @@ -3,6 +3,10 @@ from typing import List, Optional, Set, Tuple from chirpy.core.response_generator.state import NO_UPDATE +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity + +import copy + logger = logging.getLogger('chirpylogger') @dataclass @@ -32,7 +36,8 @@ class ConditionalState(object): def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Optional[str] = None, user_utterance: Optional[str] = None, user_labels: List[str] = [], bot_utterance: Optional[str] = None, bot_labels: List[str] = [], - neural_responses: Optional[List[str]] = None, num_topic_shifts: int = 0): + neural_responses: 
Optional[List[str]] = None, num_topic_shifts: int = 0, + archived_state: "State" = None, rg_that_was_taken_over: str = None, takeover_entity: WikiEntity = None): """ @param next_treelet: the name of the treelet we should run on the next turn if our response/prompt is chosen. None means turn off next turn. @param most_recent_treelet: the name of the treelet that handled this turn, if applicable @@ -59,6 +64,9 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti self.bot_labels = bot_labels self.neural_responses = neural_responses self.num_topic_shifts = num_topic_shifts + self.archived_state = archived_state + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity def __repr__(self): return f"" @@ -161,7 +174,7 @@ def update_if_chosen(self, conditional_state: ConditionalState): self.update_conv_history(conditional_state) - def update_if_not_chosen(self, conditional_state: ConditionalState): + def update_if_not_chosen(self, conditional_state: ConditionalState, rg_was_taken_over=False): """If our response/prompt has not been chosen, update state""" # Set the next_treelet for the next turn to be None (off) diff --git a/chirpy/response_generators/offensive_user/offensive_user_response_generator.py b/chirpy/response_generators/offensive_user/offensive_user_response_generator.py index dda40ae..2c3050b 100644 --- a/chirpy/response_generators/offensive_user/offensive_user_response_generator.py +++ b/chirpy/response_generators/offensive_user/offensive_user_response_generator.py @@ -78,7 +78,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state[key] += 1 # return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.handle_response = False state.offense_type = None diff --git a/chirpy/response_generators/opinion2/opinion_response_generator.py b/chirpy/response_generators/opinion2/opinion_response_generator.py index 853d142..6805d0f 100644 --- a/chirpy/response_generators/opinion2/opinion_response_generator.py +++ b/chirpy/response_generators/opinion2/opinion_response_generator.py @@ -533,7 +533,7 @@ def update_state_if_chosen(self, state: State, conditional_state : Optional[Stat if val != NO_UPDATE: setattr(state, attr, val) return state - def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: new_state = state.reset_state() new_state.num_turns_since_long_policy += 1 return new_state diff --git a/chirpy/response_generators/wiki2/response_templates/response_components.py b/chirpy/response_generators/wiki2/response_templates/response_components.py index 912a20c..771ba73 100644 --- a/chirpy/response_generators/wiki2/response_templates/response_components.py +++ b/chirpy/response_generators/wiki2/response_templates/response_components.py @@ -9,6 +9,7 @@ 'cool', 'super', 'i didn\'t know', + 'interesting', ] GENERAL_BOT_ACKNOWLEDGEMENTS = [ diff --git a/chirpy/response_generators/wiki2/state.py b/chirpy/response_generators/wiki2/state.py index 7ee3a6b..9c4b547 100644 --- a/chirpy/response_generators/wiki2/state.py +++ 
b/chirpy/response_generators/wiki2/state.py @@ -60,6 +60,7 @@ class State(BaseState): context_used: Optional[str] = None + @dataclass class ConditionalState(BaseConditionalState): # This is only used in conditional state to update the information for each entity diff --git a/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py b/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py index caa9ea9..7ffe7a7 100644 --- a/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py @@ -352,7 +352,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( prev_treelet_str=self.name, next_treelet_str=None - )) + ), + last_rg_willing_to_handover_control=True) else: return ResponseGeneratorResult( text=f"{ack} {text}", diff --git a/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py b/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py index 5bbc527..ff3d9c4 100644 --- a/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py @@ -207,7 +207,8 @@ def get_initial_response(self): return ResponseGeneratorResult(text=response, priority=ResponsePriority.STRONG_CONTINUE, needs_prompt=False, state=state, - cur_entity=entity, conditional_state=conditional_state) + cur_entity=entity, conditional_state=conditional_state, + last_rg_willing_to_handover_control=True) def get_followup_acknowledgement(self): state, utterance, response_types = self.get_state_utterance_response_types() diff --git a/chirpy/response_generators/wiki2/treelets/handback_treelet.py b/chirpy/response_generators/wiki2/treelets/handback_treelet.py new file mode 100644 index 0000000..5d0acae --- /dev/null +++ b/chirpy/response_generators/wiki2/treelets/handback_treelet.py @@ -0,0 +1,64 @@ +import random + +from chirpy.core.response_generator.treelet import Treelet +from chirpy.core.response_generator_datatypes import ResponsePriority, ResponseGeneratorResult +from chirpy.response_generators.wiki2.state import ConditionalState +from chirpy.response_generators.neural_fallback.neural_helpers import get_random_fallback_neural_response +from typing import Optional +import logging + +import chirpy.response_generators.wiki2.wiki_utils as wiki_utils +from chirpy.response_generators.wiki2.wiki_helpers import ResponseType +from chirpy.core.regex.response_lists import * +from chirpy.response_generators.wiki2.response_templates.response_components import * + +logger = logging.getLogger('chirpylogger') + + +class WikiHandBackTreelet(Treelet): + name = "wiki_handback_treelet" + + def get_acknowledgement(self): + state, utterance, response_types = self.get_state_utterance_response_types() + if ResponseType.CONFUSED in response_types: return random.choice(ERROR_ADMISSION) + + prefix = '' + if ResponseType.AGREEMENT in response_types: + return random.choice(RESPONSES_TO_USER_AGREEMENT) + if ResponseType.POS_SENTIMENT in response_types: + if ResponseType.OPINION in response_types: + prefix = random.choice(POS_OPINION_RESPONSES) + elif ResponseType.APPRECIATIVE in response_types: + return random.choice(APPRECIATION_DEFAULT_ACKNOWLEDGEMENTS) + elif ResponseType.NEG_SENTIMENT in response_types: + if ResponseType.OPINION in response_types: # negative opinion + prefix = "That's an interesting take," + else: # expression of sadness + return 
random.choice(COMMISERATION_ACKNOWLEDGEMENTS) + elif ResponseType.NEUTRAL_SENTIMENT in response_types: + if ResponseType.OPINION in response_types or ResponseType.PERSONAL_DISCLOSURE in response_types: + return random.choice(NEUTRAL_OPINION_SHARING_RESPONSES) + elif ResponseType.KNOW_MORE in response_types: + return "Yeah," + if prefix: + return prefix + return random.choice(POST_SHARING_ACK) + + def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): + state, utterance, response_types = self.get_state_utterance_response_types() + takeover_entity = state.takeover_entity + + logger.debug(f'WIKI handback_treelet is triggered.') + + wrap_up_text = self.get_acknowledgement() + + return ResponseGeneratorResult( + text=wrap_up_text, + priority=priority, + state=state, needs_prompt=True, cur_entity=self.get_current_entity(), + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=None, + rg_that_was_taken_over=self.rg.state.rg_that_was_taken_over, + takeover_entity=takeover_entity), + ) + diff --git a/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py b/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py index 61305cf..3e00a81 100644 --- a/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py @@ -53,7 +53,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): priority=priority, state=state, needs_prompt=False, cur_entity=entity, conditional_state=ConditionalState(prev_treelet_str=self.name, - next_treelet_str=self.rg.discuss_article_treelet.name) + next_treelet_str=self.rg.discuss_article_treelet.name), + last_rg_willing_to_handover_control=True ) else: # no intro paragraph available neural_response = get_random_fallback_neural_response(self.get_current_state()) diff --git a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py new file mode 100644 index 0000000..040a7c2 --- /dev/null +++ b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py @@ -0,0 +1,72 @@ +import random + +from chirpy.core.response_generator.treelet import Treelet +from chirpy.core.response_generator_datatypes import ResponsePriority, ResponseGeneratorResult +from chirpy.response_generators.wiki2.state import ConditionalState +from chirpy.response_generators.neural_fallback.neural_helpers import get_random_fallback_neural_response +from typing import Optional +import logging +import chirpy.response_generators.wiki2.wiki_utils as wiki_utils + + +from chirpy.annotators.blenderbot import BlenderBot + +logger = logging.getLogger('chirpylogger') + + +class WikiTakeOverTreelet(Treelet): + name = "wiki_takeover_treelet" + + def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): + state, utterance, response_types = self.get_state_utterance_response_types() + + rg_that_was_taken_over = self.rg.state_manager.last_state.active_rg + logger.debug(f'rg that was taken over is {rg_that_was_taken_over}.') + + cur_entity = self.get_current_entity() + takeover_entity = self.get_most_recent_able_to_takeover_entity() + + takeover_text = wiki_utils.get_takeover_text(self.rg, cur_entity, takeover_entity) + + logger.info(f"takeover_text is {takeover_text}") + + if takeover_text: + intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, takeover_entity.talkable_name) + starter_text = wiki_utils.get_random_starter_text() + return 
ResponseGeneratorResult( + text=intro_intersect_text + starter_text + takeover_text, + priority=priority, + state=state, needs_prompt=False, cur_entity=takeover_entity, + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=self.rg.handback_treelet.name, + rg_that_was_taken_over=rg_that_was_taken_over, + takeover_entity=takeover_entity), + takeover_rg_willing_to_handback_control=True + ) + + else: + neural_prefix = f'Speaking of {takeover_entity.talkable_name} and {cur_entity.talkable_name},' + takeover_neural_response = self.rg.get_neural_response(prefix=neural_prefix) + takeover_neural_response = takeover_neural_response.split('.')[0] + generated_response = takeover_neural_response[len(neural_prefix):] + logger.info(f"takeover_neural_response is {takeover_neural_response}") + if takeover_entity.talkable_name in generated_response and cur_entity.talkable_name in generated_response: + intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, + takeover_entity.talkable_name) + logger.info("takeover_neural_response is used.") + starter_text = wiki_utils.get_random_starter_text() + return ResponseGeneratorResult( + text=intro_intersect_text + starter_text + generated_response, + priority=priority, + state=state, needs_prompt=False, cur_entity=takeover_entity, + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=self.rg.handback_treelet.name, + rg_that_was_taken_over=rg_that_was_taken_over, + takeover_entity=takeover_entity), + takeover_rg_willing_to_handback_control=True + ) + else: + logger.info("takeover_neural_response is not used because it does not contain takeover_entity and cur_entity in it.") + logger.info( + "WIKI fails to take over.") + return None \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_response_generator.py b/chirpy/response_generators/wiki2/wiki_response_generator.py index b2c7436..df30c9d 100644 --- a/chirpy/response_generators/wiki2/wiki_response_generator.py +++ b/chirpy/response_generators/wiki2/wiki_response_generator.py @@ -1,5 +1,7 @@ import os import logging +from concurrent import futures + from typing import Optional, Set, Tuple import random @@ -23,6 +25,10 @@ from chirpy.annotators.corenlp import Sentiment from chirpy.response_generators.wiki2.state import State,ConditionalState, NO_UPDATE +from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet +from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet + +from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier logger = logging.getLogger('chirpylogger') @@ -30,15 +36,17 @@ from chirpy.annotators.responseranker import ResponseRanker use_responseranker = True except ModuleNotFoundError: - logger.warning('ResponseRanker module not found, defaulting to original DialoGPT and GPT2 Rankers') + logger.warning('ResponseRanker module not found, defaulting to original DialoGPT and GPT2 Rankers') from chirpy.annotators.dialogptranker import DialoGPTRanker from chirpy.annotators.gpt2ranker import GPT2Ranker use_responseranker = False +import threading + class WikiResponseGenerator(ResponseGenerator): name='WIKI' - killable = True + killable = False def __init__(self, state_manager) -> None: 
self.discuss_section_treelet = DiscussSectionTreelet(self) self.discuss_section_further_treelet = DiscussSectionFurtherTreelet(self) self.get_opinion_treelet = GetOpinionTreelet(self) + self.takeover_treelet = WikiTakeOverTreelet(self) + self.handback_treelet = WikiHandBackTreelet(self) treelets = {t.name: t for t in [self.check_user_knowledge_treelet, self.acknowledge_user_knowledge_treelet, self.factoid_treelet, self.intro_entity_treelet, self.combined_til_treelet, self.discuss_article_treelet, self.discuss_section_treelet, - self.discuss_section_further_treelet, self.get_opinion_treelet]} + self.discuss_section_further_treelet, self.get_opinion_treelet, + self.takeover_treelet, self.handback_treelet]} super().__init__(state_manager, treelets=treelets, state_constructor=State, can_give_prompts=True, conditional_state_constructor=ConditionalState, @@ -641,7 +652,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: state = super().update_state_if_not_chosen(state, conditional_state) state.cur_doc_title = None state.suggested_sections = [] @@ -658,3 +669,15 @@ def update_state_if_not_chosen(self, state: State, conditional_state: Optional[C state.context_used = None return state + + def get_takeover_response(self): + logger.error("WIKI TAKEOVER") + return self.takeover_treelet.get_response(ResponsePriority.FORCE_START) + + def get_neural_response(self, prefix=None, allow_questions=False, conditions=None) -> Optional[str]: + if conditions is None: conditions = [] + offensive_classifier = OffensiveClassifier() + conditions = [lambda response: not offensive_classifier.contains_offensive(response)] + conditions + response = super().get_neural_response(prefix, allow_questions, conditions) + if response is None: return "That's great to hear." 
+ return response \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_utils.py b/chirpy/response_generators/wiki2/wiki_utils.py index 0e8ea8a..7a6ede8 100644 --- a/chirpy/response_generators/wiki2/wiki_utils.py +++ b/chirpy/response_generators/wiki2/wiki_utils.py @@ -23,6 +23,8 @@ import random import math +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity + lucene_stopwords = {'a', 'an', 'and', 'are', 'as', 'at', 'be', 'but', 'by', 'for', 'if', 'in', 'into', 'is', 'it', 'no', 'not', 'of', 'on', 'or', 'such', 'that', 'the', 'their', 'then', 'there', 'these', 'they', 'this', 'to', 'was', @@ -135,6 +137,8 @@ def filter_highlight_sections(title, es_sections): wiki_sections.append(wiki_section) filtered_sections = wiki_sections + logger.error(f"SECTIONS: {filtered_sections}") + # Filter sections and log why the were filtered filtered_sections = filter_and_log(lambda section: not contains_offensive(section.title), filtered_sections, 'Wiki Highlights', reason_for_filtering='section title contains offensive phrases') @@ -162,7 +166,10 @@ def filter_highlight_sections(title, es_sections): filtered_sections = filter_and_log(lambda section: not contains_offensive(section.highlight), filtered_sections, 'Wiki Highlights', reason_for_filtering='section highlight contains offensive phrases') + logger.error(f"SECTIONS2: {filtered_sections}") return filtered_sections + + def filter_sections(title, es_sections): wiki_sections = [] for section in es_sections['hits']['hits']: @@ -262,12 +269,107 @@ def search_wiki_sections(doc_title: str, phrases: tuple, wiki_links:tuple) -> Li } } } + import json + logger.error(f"QUERY: {json.dumps(query, indent=2)}") sections = es.search(index='enwiki-20200920-sections', body=query) - logger.debug(f"For phrases {phrases}, in wikipedia article {doc_title}, found following sections (unfiltered) {sections}") + logger.error(f"For phrases {phrases}, in wikipedia article {doc_title}, found following sections (unfiltered) {sections}") filtered_sections = filter_highlight_sections(doc_title, sections) return filtered_sections +def prune_section(section): + return section['text'][0] in {'†', '+', '*'} + + +def clean_takeover_wiki_text(text: str) -> str: + modified_text = clean_wiki_text(text) + index_caption = modified_text.find(']]') + if index_caption != -1: + modified_text = modified_text[index_caption + 2:] + return modified_text + + +def summarize_takeover_candidate_text(rg, text: str, span_to_keep: str, max_words: int = 50, max_sents: int = 3) -> str: + logger.info(f'Summarizing takeover text: {text}') + + local_sentseg_fn = lambda text: re.split('[.\n]', text) + sentseg_fn = NLTKSentenceSegmenter( + rg.state_manager).execute if rg.state_manager else local_sentseg_fn + sentences = sentseg_fn(text) + + summary = '' + num_sentences = 0 + found = False + for sentence in sentences: + if sentence == '': + continue + if "|" in sentence or "[" in sentence or "]" in sentence or "{" in sentence or "}" in sentence: + continue + if span_to_keep in sentence: + found = True + summary += sentence + ('.' 
if sentence[-1] not in {'.', '!', '?'} else ' ') + num_sentences += 1 + if found and (num_sentences > max_sents or len(summary.split(' ')) < max_words): + break + return summary + + +def search_wiki_intersect_sections(rg, doc_title: str, search_entity: WikiEntity) -> List[str]: + query = {'query': {'bool': {'filter': [ + {'term': {'doc_title': doc_title}}]}}} + sections = es.search(index='enwiki-20200920-sections', body=query, size=100) + top_spans = list(search_entity.anchortext_counts.keys())[:3] + logger.error(f"SPAN: {search_entity.anchortext_counts}") + candidate_texts = [] + # logger.error(f"TEXTS: {sections['hits']['hits']}") + for section in sections['hits']['hits']: + source = section['_source'] + if not prune_section(source): + source_texts = list(filter(None, re.split('\n', source['text']))) + for text in source_texts: + cleaned_text = clean_takeover_wiki_text(text) + if not contains_offensive(cleaned_text) and not rg.has_overlap_with_history(cleaned_text, threshold=0.8): + for s in top_spans: + if s in cleaned_text: + summarized_text = summarize_takeover_candidate_text(rg, cleaned_text, span_to_keep=s) + candidate_texts.append(summarized_text) + break + logger.info(f"candidate_texts from doc_title {doc_title} is {candidate_texts}") + return candidate_texts + + +def get_takeover_text(rg, cur_entity: WikiEntity, takeover_entity: WikiEntity) -> Optional[str]: + related_wiki_texts_from_cur_entity_doc = search_wiki_intersect_sections(rg, cur_entity.talkable_name, takeover_entity) + if related_wiki_texts_from_cur_entity_doc: + return random.choice(related_wiki_texts_from_cur_entity_doc) + + related_wiki_texts_from_takeover_entity_doc = search_wiki_intersect_sections(rg, takeover_entity.talkable_name, cur_entity) + if related_wiki_texts_from_takeover_entity_doc: + return random.choice(related_wiki_texts_from_takeover_entity_doc) + + return None + + +INTRO_INTERSECT_TEXT = ["Speaking of {} and {}, ", + "Relating to {} and {}, ", + "Since you mentioned {} and {}, "] + + +def get_random_intro_intersect_text(cur_entity: str, takeover_entity: str) -> str: + return random.choice(INTRO_INTERSECT_TEXT).format(cur_entity, takeover_entity) + + +STARTER_TEXTS = ["did you know that ", + "I recently learned that ", + "I was reading recently and found out that ", + "did you know that ", + "I was interested to learn that "] + + +def get_random_starter_text(): + return random.choice(STARTER_TEXTS) + + def get_text_for_entity(entity): results = es.search(index='enwiki-20200920-sections', body={ 'query': { @@ -291,6 +393,7 @@ def replaceByLength(matchobj): sections = sorted(sections, key=(lambda x: -len(x[1]))) return sections + def check_section_summary(rg, section_summary, selected_section, allow_history_overlap=False): """ Check that the section summary is present, non-offensive, and does not overlap with history. 
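Note for readers skimming the patch: the sketch below is illustrative only and is not part of the diff. It shows, under the assumption that the inputs are plain placeholder strings rather than real WikiEntity objects, how the templates added to wiki_utils.py above are combined by the new takeover treelet into the opening takeover utterance.

import random

# Modeled on the INTRO_INTERSECT_TEXT / STARTER_TEXTS templates added to
# wiki_utils.py in this patch; duplicated here so the sketch runs standalone.
INTRO_INTERSECT_TEXT = ["Speaking of {} and {}, ",
                        "Relating to {} and {}, ",
                        "Since you mentioned {} and {}, "]

STARTER_TEXTS = ["did you know that ",
                 "I recently learned that ",
                 "I was reading recently and found out that ",
                 "I was interested to learn that "]

def compose_takeover_opener(cur_entity_name: str, takeover_entity_name: str, takeover_text: str) -> str:
    # Mirrors the happy path of takeover_treelet.get_response:
    # intro template + starter phrase + retrieved section text.
    intro = random.choice(INTRO_INTERSECT_TEXT).format(cur_entity_name, takeover_entity_name)
    return intro + random.choice(STARTER_TEXTS) + takeover_text

# Placeholder inputs; in the real flow these come from the entity tracker and
# from search_wiki_intersect_sections / get_takeover_text.
print(compose_takeover_opener("pizza", "Naples",
                              "the dish is closely associated with Naples."))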
diff --git a/env.list b/env.list new file mode 100644 index 0000000..119b5f5 --- /dev/null +++ b/env.list @@ -0,0 +1,22 @@ +export PYTHONPATH=$(pwd) +export ES_USER=chirpy1 +export ES_PASSWORD=4sMoNKNxQkMeVtrlEYqsK2Nzo7kBNU@ +export ES_HOST=search-genie-search-dev-36ydzvzvwb7oyyzvdbrs63rdny.us-west-2.es.amazonaws.com +export ES_PORT=443 +export ES_SCHEME=https +export POSTGRES_HOST=localhost +export POSTGRES_USER=postgres +export POSTGRES_PASSWORD=qyhqae-4Sepzy-zecget +export corenlp_URL=4080 +export dialogact_URL=4081 +export g2p_URL=4082 +export gpt2ed_URL=4083 +export question_URL=4084 +export convpara_URL=4085 +export entitylinker_URL=4086 +export blenderbot_URL=4087 +export responseranker_URL=4088 +export stanfordnlp_URL=4089 +export infiller_URL=4090 +export postgresql_URL=5432 +export usecolbert=false
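Usage note (not part of the patch): env.list is a list of shell export statements, so the working assumption is that it is sourced into the shell that runs the bot and that each *_URL variable holds the local port of a port-forwarded remote module. The Python sketch below shows one way such variables could be read; the helper name and the http://localhost URL scheme are assumptions for illustration, not something this diff establishes.

import os

def local_module_url(module_name: str, default_port: str) -> str:
    # Hypothetical helper: looks up e.g. gpt2ed_URL=4083 from the environment
    # (as exported by env.list) and builds a localhost URL for the
    # port-forwarded service. The URL format is an assumption.
    port = os.environ.get(f"{module_name}_URL", default_port)
    return f"http://localhost:{port}"

if __name__ == "__main__":
    for module, default in [("corenlp", "4080"), ("blenderbot", "4087")]:
        print(module, "->", local_module_url(module, default))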