Source code for jaclearn.nlp.tree.constituency

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : constituency.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/04/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

"""
Constituency Tree.
"""

import jacinle.random as random
from jacinle.utils.enum import JacEnum

from .ptb import PTBNode
from .traversal import traversal

# Label handed to `PTBNode` (as its first argument) for the placeholder nodes
# created by `_new_temp_node`, e.g. the intermediate nodes inserted while
# binarizing a tree.
TEMP_NODE = '<TEMP>'


def _new_temp_node(token=None):
    """Create a placeholder node labelled with ``TEMP_NODE``.

    Args:
        token: forwarded unchanged as the second argument of ``PTBNode``;
            defaults to ``None``.

    Returns:
        The freshly constructed ``PTBNode``.
    """
    placeholder = PTBNode(TEMP_NODE, token)
    return placeholder


def binarize_tree(tree):
    """Return a binarized copy of ``tree``.

    The input is cloned first (``tree.clone()``) and only the clone is
    modified. The clone is processed bottom-up:

    * a node with exactly one child is removed and its child takes its place
      (and its ``vtype``) under the former parent;
    * a node with more than two children has them split into two halves
      recursively, inserting ``<TEMP>`` nodes for every half that holds more
      than one child;
    * nodes with zero or two children are left untouched.

    Args:
        tree: the root node of the tree to binarize. It must provide
            ``clone``, ``children``, ``vtype``, ``attach``, ``detach`` and
            ``append_child``.
            NOTE(review): ``detach()`` is expected to return
            ``(father, sibling_ind)`` with ``father is None`` for a root node;
            confirm against the node implementation.

    Returns:
        The root of the binarized clone. This may be a different node than
        the clone of ``tree`` when the root itself had a single child.
    """

    def split(parent, children):
        # Precondition: ``children`` are already detached and there are at
        # least two of them, so neither half below is ever empty.
        mid = len(children) // 2
        for part in (children[:mid], children[mid:]):
            if len(part) == 1:
                parent.append_child(part[0])
            else:
                holder = _new_temp_node()
                holder.attach(parent)
                split(holder, part)

    def dfs(node):
        # Iterate over a snapshot: collapsing a unary child detaches it from
        # ``node.children`` and inserts its replacement while we are looping.
        for child in list(node.children):
            dfs(child)

        nr_children = len(node.children)
        if nr_children == 1:
            only_child = node.children[0]
            parent, sibling_ind = node.detach()
            only_child.detach()
            # The surviving child inherits the label of the removed node.
            only_child.vtype = node.vtype
            if parent is None:
                # ``node`` was the root: its child becomes the new root.
                node = only_child
            else:
                only_child.attach(parent, sibling_ind)
        elif nr_children > 2:
            children = node.children.copy()
            for child in children:
                child.detach()
            split(node, children)
        return node

    return dfs(tree.clone())
def make_balanced_binary_tree(sequence):
    """Build a binary tree whose leaves carry the items of ``sequence``.

    One ``<TEMP>`` node is created per item and attached, in order, under a
    fresh ``<TEMP>`` root; the resulting flat tree is then handed to
    ``binarize_tree``, which recursively splits the children in halves.

    Args:
        sequence: an iterable of tokens; each one is passed to
            ``_new_temp_node``.

    Returns:
        Whatever ``binarize_tree`` returns for the flat tree.
    """
    flat_root = _new_temp_node()
    for token in sequence:
        leaf = _new_temp_node(token)
        leaf.attach(flat_root)
    return binarize_tree(flat_root)
class StepMaskSelectionMode(JacEnum):
    """How ``compose_bianry_tree_step_masks`` picks the node merged at a step."""

    # Pick the mergeable node that comes first in the traversal order.
    FIRST = 'first'

    # Pick one of the mergeable nodes via ``random.choice_list``.
    RANDOM = 'random'
def compose_bianry_tree_step_masks(tree, selection='first'):
    """Compute the sequence of merge steps that rebuilds a binary tree.

    The function keeps a frontier of "clean" nodes, initially the leaves. At
    every step one node whose children are all on the frontier is selected,
    its two children are removed from the frontier and the node itself is
    added. This repeats until a single node is left.

    NOTE: the misspelt name (``bianry``) is the public name of this function
    and is kept as is.

    Args:
        tree: the root of the tree. Every selected node must have exactly two
            children that are adjacent on the ordered frontier (checked with
            ``assert``).
        selection: converted with ``StepMaskSelectionMode.from_string``.
            ``'first'`` selects the mergeable node with the smallest index in
            ``traversal(tree, 'pre')``; ``'random'`` selects one through
            ``random.choice_list``.

    Returns:
        A list with one ``(lson, mask)`` tuple per step. ``lson`` is the
        position of the selected node's first child on the frontier, where the
        frontier is ordered as in ``traversal(tree, 'pre')``. ``mask`` is a
        list of 0/1 over that ordered frontier, with 1 for every node that is
        the first child (``sibling_ind == 0``) of a currently mergeable node.

    Raises:
        ValueError: if the selection mode is neither FIRST nor RANDOM.
    """
    mode = StepMaskSelectionMode.from_string(selection)

    # presumably pre-order; the code only relies on it being a fixed order.
    ordered_nodes = list(traversal(tree, 'pre'))
    frontier = {node for node in ordered_nodes if node.is_leaf}
    # Every node that has ever been on the frontier (never mergeable again).
    visited = frontier.copy()

    steps = []
    while len(frontier) > 1:
        # Nodes not yet merged whose children are all on the frontier,
        # mapped to their index in the traversal order.
        mergeable = {}
        for position, node in enumerate(ordered_nodes):
            if node in visited:
                continue
            if all(child in frontier for child in node.children):
                mergeable[node] = position

        frontier_in_order = [node for node in ordered_nodes if node in frontier]
        frontier_rank = {node: rank for rank, node in enumerate(frontier_in_order)}

        if mode is StepMaskSelectionMode.FIRST:
            chosen = ordered_nodes[min(mergeable.values())]
        elif mode is StepMaskSelectionMode.RANDOM:
            chosen = random.choice_list(list(mergeable))
        else:
            raise ValueError('Unknown StepMaskSelectionMode: {}.'.format(mode))

        # Project the mergeable nodes onto the frontier: mark their first child.
        mask = [
            1 if (node.sibling_ind == 0 and node.father in mergeable) else 0
            for node in frontier_in_order
        ]

        assert len(chosen.children) == 2
        # sanity check: the two children must be neighbours on the frontier.
        lson = frontier_rank[chosen.children[0]]
        rson = frontier_rank[chosen.children[1]]
        assert lson + 1 == rson

        frontier.difference_update(chosen.children)
        frontier.add(chosen)
        visited.add(chosen)

        steps.append((lson, mask))

    return steps