From b141fdb016f58749bcadd8b3d11e6025e815338e Mon Sep 17 00:00:00 2001 From: EvaLiyt Date: Tue, 11 Jun 2024 13:52:52 +1200 Subject: [PATCH] implement labelClade function #494 --- .../lphy/base/evolution/tree/LabelClade.java | 49 ++++++++++++++ .../base/evolution/tree/SubstituteClade.java | 20 ------ .../main/java/lphy/base/spi/LPhyBaseImpl.java | 7 +- .../base/evolution/tree/LabelCladeTest.java | 66 +++++++++++++++++++ 4 files changed, 117 insertions(+), 25 deletions(-) create mode 100644 lphy-base/src/main/java/lphy/base/evolution/tree/LabelClade.java create mode 100644 lphy-base/src/test/java/lphy/base/evolution/tree/LabelCladeTest.java diff --git a/lphy-base/src/main/java/lphy/base/evolution/tree/LabelClade.java b/lphy-base/src/main/java/lphy/base/evolution/tree/LabelClade.java new file mode 100644 index 00000000..d7642408 --- /dev/null +++ b/lphy-base/src/main/java/lphy/base/evolution/tree/LabelClade.java @@ -0,0 +1,49 @@ +package lphy.base.evolution.tree; + +import lphy.core.model.DeterministicFunction; +import lphy.core.model.Value; +import lphy.core.model.annotation.GeneratorInfo; +import lphy.core.model.annotation.ParameterInfo; + +public class LabelClade extends DeterministicFunction { + public static final String treeParamName = "tree"; + public static final String taxaParamName = "taxa"; + public static final String labelParamName = "label"; + public LabelClade(@ParameterInfo(name = treeParamName, description = "the tree to label")Value tree, + @ParameterInfo(name = taxaParamName, description = "the root of the taxa names would be labelled") Value taxa, + @ParameterInfo(name = labelParamName, description = "the label") Value label){ + if (tree == null) throw new IllegalArgumentException("The tree cannot be null!"); + if (taxa == null) throw new IllegalArgumentException("The taxa name cannot be null!"); + if (label == null) throw new IllegalArgumentException("Please label the mrca of the taxa!"); + setParam(treeParamName, tree); + setParam(taxaParamName, taxa); + setParam(labelParamName, label); + } + @GeneratorInfo(name = "labelClade", description = "Find the most recent common ancestor of given taxa names in the tree and give it a label.") + @Override + public Value apply() { + // make a deep copy of the tree + TimeTree tree = getTree().value(); + TimeTree newTree = new TimeTree(tree); + + // find mrca node + Value treeValue = new Value<>("newTree", newTree); + MRCA mrcaInstance = new MRCA(treeValue, getTaxa()); + TimeTreeNode mrca = mrcaInstance.apply().value(); + + // set label metadata + String label = getLabel().value(); + mrca.setMetaData("label", label); + + return new Value<>(null, newTree,this); + } + public Value getTree(){ + return getParams().get(treeParamName); + } + public Value getTaxa(){ + return getParams().get(taxaParamName); + } + public Value getLabel(){ + return getParams().get(labelParamName); + } +} diff --git a/lphy-base/src/main/java/lphy/base/evolution/tree/SubstituteClade.java b/lphy-base/src/main/java/lphy/base/evolution/tree/SubstituteClade.java index e1e2fd31..eabef86d 100644 --- a/lphy-base/src/main/java/lphy/base/evolution/tree/SubstituteClade.java +++ b/lphy-base/src/main/java/lphy/base/evolution/tree/SubstituteClade.java @@ -6,11 +6,6 @@ import lphy.core.model.annotation.ParameterInfo; public class SubstituteClade extends DeterministicFunction { -// Value baseTree; -// Value cladeTree; -// Value time; -// Value node; -// Value nodeLabel; public static final String baseTreeName = "baseTree"; public static final String cladeTreeName = "cladeTree"; public static final String nodeName = "node"; @@ -34,11 +29,6 @@ public SubstituteClade(@ParameterInfo(name = baseTreeName, description = "the tr if (time != null) setParam(mutationHappenTimeName, time); setParam(nodeLabelName,nodeLabel); -// this.baseTree = baseTree; -// this.cladeTree = cladeTree; -// this.node = node; -// this.nodeLabel = nodeLabel; -// this.time = time; } @GeneratorInfo(name = "substituteClade", examples = {"substituteClade.lphy"}, @@ -93,16 +83,6 @@ public Value apply() { return new Value<>(null, newTree, this); } -// @Override -// public Map getParams() { -// SortedMap map = new TreeMap<>(); -// if (baseTree != null) map.put(baseTreeName, baseTree); -// if (cladeTree != null) map.put(cladeTreeName, cladeTree); -// if (node != null) map.put(nodeName, node); -// if (time != null) map.put(mutationHappenTimeName, time); -// if (nodeLabelName != null) map.put(nodeLabelName, nodeLabel); -// return map; -// } public Value getBaseTree() { return getParams().get(baseTreeName); } diff --git a/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java b/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java index 1313a988..99f80494 100644 --- a/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java +++ b/lphy-base/src/main/java/lphy/base/spi/LPhyBaseImpl.java @@ -20,10 +20,7 @@ import lphy.base.evolution.likelihood.PhyloCTMC; import lphy.base.evolution.likelihood.PhyloCTMCSiteModel; import lphy.base.evolution.substitutionmodel.*; -import lphy.base.evolution.tree.MRCA; -import lphy.base.evolution.tree.SampleBranch; -import lphy.base.evolution.tree.SubsampledTree; -import lphy.base.evolution.tree.SubstituteClade; +import lphy.base.evolution.tree.*; import lphy.base.function.*; import lphy.base.function.alignment.*; import lphy.base.function.datatype.AminoAcidsFunction; @@ -97,7 +94,7 @@ public List> declareFunctions() { VariableSites.class, InvariableSites.class, CopySites.class, // Tree LocalBranchRates.class, ExtantTree.class, PruneTree.class, LocalClock.class, - SubstituteClade.class, MRCA.class,//NodeCount.class, TreeLength.class, + SubstituteClade.class, MRCA.class, LabelClade.class,//NodeCount.class, TreeLength.class, // Matrix BinaryRateMatrix.class, MigrationMatrix.class, MigrationCount.class, // IO diff --git a/lphy-base/src/test/java/lphy/base/evolution/tree/LabelCladeTest.java b/lphy-base/src/test/java/lphy/base/evolution/tree/LabelCladeTest.java new file mode 100644 index 00000000..16a679a2 --- /dev/null +++ b/lphy-base/src/test/java/lphy/base/evolution/tree/LabelCladeTest.java @@ -0,0 +1,66 @@ +package lphy.base.evolution.tree; + +import lphy.base.function.tree.Newick; +import lphy.core.model.Value; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LabelCladeTest { + String newickTree; + + @BeforeEach + void setUp() { + newickTree = "(((((1:2.0, (2:1.0, 3:1.0):1.0):2.0, (5:2.0, 6:2.0):2.0):2.0):0.0,4:6.0):6.0, 7:12.0)"; + } + + @Test + void applyTest1() { + TimeTree tree = Newick.parseNewick(newickTree); + String[] taxa = {"4", "7"}; + String label = "label"; + + Value treeValue = new Value<>("tree", tree); + Value taxaValue = new Value<>("taxa", taxa); + Value labelValue = new Value<>("label", label); + + LabelClade labelCladeInstance = new LabelClade(treeValue, taxaValue, labelValue); + MRCA mrcaInstance = new MRCA(treeValue,taxaValue); + TimeTree observe = labelCladeInstance.apply().value(); + TimeTreeNode mrcaNode = mrcaInstance.apply().value(); + + // only index should be same for mrca + assertEquals(mrcaNode.getIndex(), observe.getRoot().getIndex()); + // check label + assertEquals(label, observe.getRoot().getMetaData("label")); + assertEquals(observe.getRoot(), observe.getLabeledNode(label)); + } + + @Test + void applyTest2() { + TimeTree tree = Newick.parseNewick(newickTree); + String[] taxa = {"3", "5"}; + String label = "label"; + + Value treeValue = new Value<>("tree", tree); + Value taxaValue = new Value<>("taxa", taxa); + Value labelValue = new Value<>("label", label); + + LabelClade labelCladeInstance = new LabelClade(treeValue, taxaValue, labelValue); + MRCA mrcaInstance = new MRCA(treeValue,taxaValue); + TimeTree observe = labelCladeInstance.apply().value(); + TimeTreeNode mrcaNode = mrcaInstance.apply().value(); + + boolean found = false; + for (TimeTreeNode node: observe.getInternalNodes()){ + if (node.getAllLeafNodes().size() == 5 && node.age == 4){ + assertEquals(node.getIndex(), mrcaNode.getIndex()); + assertEquals(label, node.getMetaData("label")); + assertEquals(node, observe.getLabeledNode(label)); + found = true; + } + } + assert(found); + } +}