const Matrix = require('../..');
const isNumber = require('../../util/isNumber');
const { INVALID_ARRAY, EXPECTED_ARRAY_OF_NUMBERS_OR_MATRICES, INVALID_SQUARE_MATRIX } = require('../../Error');
/**
 * Generates a diagonal Matrix if the argument is an array of numbers,
 * or a block diagonal Matrix if the argument is an array of square Matrices.
 * An empty array yields a 0x0 Matrix.
 * @memberof Matrix
 * @static
 * @param {(number[]|Matrix[])} values - Array of numbers or square Matrices (not mixed)
 * @returns {Matrix} Diagonal or block diagonal Matrix
 * @throws {Error} INVALID_ARRAY if `values` is not an array
 * @throws {Error} EXPECTED_ARRAY_OF_NUMBERS_OR_MATRICES if an entry is neither a
 *   number nor a Matrix, or if numbers and Matrices are mixed
 * @throws {Error} INVALID_SQUARE_MATRIX if a Matrix entry is not square
 */
function diag(values) {
  if (!Array.isArray(values)) {
    throw new Error(INVALID_ARRAY);
  }
  const variant = detectDiagVariant(values);
  if (variant === 'number') {
    const n = values.length;
    return Matrix.generate(n, n, (i, j) => (i === j ? values[i] : 0));
  }
  // Every entry is a square Matrix, or `values` is empty (producing a 0x0 Matrix).
  return buildBlockDiagonal(values);
}

// Validates every entry of `values`; returns 'number', 'square', or undefined for an empty array.
function detectDiagVariant(values) {
  let variant;
  for (const entry of values) {
    const entryIsNumber = isNumber(entry);
    if (!entryIsNumber && !(entry instanceof Matrix)) {
      throw new Error(EXPECTED_ARRAY_OF_NUMBERS_OR_MATRICES);
    }
    // Squareness is checked before the variant so a non-square Matrix reports INVALID_SQUARE_MATRIX.
    if (!entryIsNumber && !entry.isSquare()) {
      throw new Error(INVALID_SQUARE_MATRIX);
    }
    const entryVariant = entryIsNumber ? 'number' : 'square';
    if (variant === undefined) {
      variant = entryVariant;
    } else if (variant !== entryVariant) {
      throw new Error(EXPECTED_ARRAY_OF_NUMBERS_OR_MATRICES);
    }
  }
  return variant;
}

// Assembles square Matrices along the main diagonal; 0x0 Matrices contribute no rows/columns.
function buildBlockDiagonal(matrices) {
  const blockOf = []; // blockOf[k]: index of the Matrix that owns row/column k
  const offsetOf = []; // offsetOf[b]: first row/column occupied by Matrix b
  for (const [b, m] of matrices.entries()) {
    offsetOf.push(blockOf.length);
    const len = m.size()[0];
    for (let k = 0; k < len; k++) {
      blockOf.push(b);
    }
  }
  const size = blockOf.length;
  // Each cell is resolved independently, so the result does not depend on
  // the order in which Matrix.generate visits cells.
  return Matrix.generate(size, size, (i, j) => {
    const b = blockOf[i];
    if (blockOf[j] !== b) {
      return 0;
    }
    const offset = offsetOf[b];
    return matrices[b]._matrix[i - offset][j - offset];
  });
}
// CommonJS export: the module's value is the `diag` function itself.
module.exports = diag;