001/*- 002 ******************************************************************************* 003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd. 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 * 009 * Contributors: 010 * Peter Chang - initial API and implementation and/or initial documentation 011 *******************************************************************************/ 012 013package org.eclipse.january.dataset; 014 015import java.util.Arrays; 016import java.util.List; 017 018/** 019 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes 020 * which have lower rank and outputs to a third dataset 021 */ 022public class BroadcastPairIterator extends BroadcastIterator { 023 private int[] aShape; 024 private int[] bShape; 025 private int[] aStride; 026 private int[] bStride; 027 private int[] oStride; 028 029 final private int endrank; 030 031 private final int[] aDelta, bDelta; 032 private final int[] oDelta; // this being non-null means output is different from inputs 033 private final int aStep, bStep, oStep; 034 private int aMax, bMax; 035 private int aStart, bStart, oStart; 036 037 /** 038 * 039 * @param a 040 * @param b 041 * @param o (can be null for new dataset, a or b) 042 * @param createIfNull 043 */ 044 public BroadcastPairIterator(Dataset a, Dataset b, Dataset o, boolean createIfNull) { 045 super(a, b, o); 046 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), b.getShapeRef(), o == null ? null : o.getShapeRef()); 047 048 maxShape = fullShapes.remove(0); 049 050 oStride = null; 051 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 052 throw new IllegalArgumentException("Output does not match broadcasted shape"); 053 } 054 aShape = fullShapes.remove(0); 055 bShape = fullShapes.remove(0); 056 057 int rank = maxShape.length; 058 endrank = rank - 1; 059 060 aDataset = a.reshape(aShape); 061 bDataset = b.reshape(bShape); 062 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 063 bStride = BroadcastUtils.createBroadcastStrides(bDataset, maxShape); 064 if (outputA) { 065 oStride = aStride; 066 oDelta = null; 067 oStep = 0; 068 } else if (outputB) { 069 oStride = bStride; 070 oDelta = null; 071 oStep = 0; 072 } else if (o != null) { 073 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 074 oDelta = new int[rank]; 075 oStep = o.getElementsPerItem(); 076 } else if (createIfNull) { 077 oDataset = BroadcastUtils.createDataset(a, b, maxShape); 078 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 079 oDelta = new int[rank]; 080 oStep = oDataset.getElementsPerItem(); 081 } else { 082 oDelta = null; 083 oStep = 0; 084 } 085 086 pos = new int[rank]; 087 aDelta = new int[rank]; 088 aStep = aDataset.getElementsPerItem(); 089 bDelta = new int[rank]; 090 bStep = bDataset.getElementsPerItem(); 091 for (int j = endrank; j >= 0; j--) { 092 aDelta[j] = aStride[j] * aShape[j]; 093 bDelta[j] = bStride[j] * bShape[j]; 094 if (oDelta != null) { 095 oDelta[j] = oStride[j] * maxShape[j]; 096 } 097 } 098 if (endrank < 0) { 099 aMax = aStep; 100 bMax = bStep; 101 } else { 102 aMax = Integer.MIN_VALUE; // use max delta 103 bMax = Integer.MIN_VALUE; 104 for (int j = endrank; j >= 0; j--) { 105 if (aDelta[j] > aMax) { 106 aMax = aDelta[j]; 107 } 108 if (bDelta[j] > bMax) { 109 bMax = bDelta[j]; 110 } 111 } 112 } 113 aStart = aDataset.getOffset(); 114 aMax += aStart; 115 bStart = bDataset.getOffset(); 116 bMax += bStart; 117 oStart = oDelta == null ? 0 : oDataset.getOffset(); 118 reset(); 119 } 120 121 @Override 122 public boolean hasNext() { 123 int j = endrank; 124 int oldA = aIndex; 125 int oldB = bIndex; 126 for (; j >= 0; j--) { 127 pos[j]++; 128 aIndex += aStride[j]; 129 bIndex += bStride[j]; 130 if (oDelta != null) 131 oIndex += oStride[j]; 132 if (pos[j] >= maxShape[j]) { 133 pos[j] = 0; 134 aIndex -= aDelta[j]; // reset these dimensions 135 bIndex -= bDelta[j]; 136 if (oDelta != null) 137 oIndex -= oDelta[j]; 138 } else { 139 break; 140 } 141 } 142 if (j == -1) { 143 if (endrank >= 0) { 144 aIndex = aMax; 145 bIndex = bMax; 146 return false; 147 } 148 aIndex += aStep; 149 bIndex += bStep; 150 if (oDelta != null) 151 oIndex += oStep; 152 } 153 if (outputA) { 154 oIndex = aIndex; 155 } else if (outputB) { 156 oIndex = bIndex; 157 } 158 159 if (aIndex == aMax || bIndex == bMax) 160 return false; 161 162 if (read) { 163 if (oldA != aIndex) { 164 if (asDouble) { 165 aDouble = aDataset.getElementDoubleAbs(aIndex); 166 } else { 167 aLong = aDataset.getElementLongAbs(aIndex); 168 } 169 } 170 if (oldB != bIndex) { 171 if (asDouble) { 172 bDouble = bDataset.getElementDoubleAbs(bIndex); 173 } else { 174 bLong = bDataset.getElementLongAbs(bIndex); 175 } 176 } 177 } 178 179 return true; 180 } 181 182 /** 183 * @return shape of first broadcasted dataset 184 */ 185 public int[] getFirstShape() { 186 return aShape; 187 } 188 189 /** 190 * @return shape of second broadcasted dataset 191 */ 192 public int[] getSecondShape() { 193 return bShape; 194 } 195 196 @Override 197 public void reset() { 198 for (int i = 0; i <= endrank; i++) 199 pos[i] = 0; 200 201 if (endrank >= 0) { 202 pos[endrank] = -1; 203 aIndex = aStart - aStride[endrank]; 204 bIndex = bStart - bStride[endrank]; 205 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 206 } else { 207 aIndex = aStart - aStep; 208 bIndex = bStart - bStep; 209 oIndex = oStart - oStep; 210 } 211 212 if (aIndex == 0 || bIndex == 0) { // for zero-ranked datasets 213 if (read) { 214 storeCurrentValues(); 215 } 216 if (aMax == aIndex) 217 aMax++; 218 if (bMax == bIndex) 219 bMax++; 220 } 221 } 222}