LeetCode 1277 - Count Square Submatrices with All Ones

https://leetcode.com/problems/count-square-submatrices-with-all-ones/

Given an m × n matrix of ones and zeros, return how many square submatrices have all ones.

Example 1:

Input: matrix =
[
[0,1,1,1],
[1,1,1,1],
[0,1,1,1]
]
Output: 15
Explanation:
There are 10 squares of side 1.
There are 4 squares of side 2.
There is 1 square of side 3.
Total number of squares = 10 + 4 + 1 = 15.
Example 2:

Input: matrix =
[
[1,0,1],
[1,1,0],
[1,1,0]
]
Output: 7
Explanation:
There are 6 squares of side 1.
There is 1 square of side 2.
Total number of squares = 6 + 1 = 7.

Constraints:

1 <= matrix.length <= 300
1 <= matrix[0].length <= 300
0 <= matrix[i][j] <= 1

Solutions (C++):
// Prefix-sum approach (not optimized): for every cell that is 1, grow a square
// anchored at that cell (its top-left corner) one side at a time, and use a
// 2-D prefix sum to check in O(1) whether the square contains only ones.
// Time O(rows * cols * min(rows, cols)), extra space O(rows * cols).
class Solution {
public:
    /// Counts the square submatrices of `matrix` whose cells are all 1.
    /// @param matrix rectangular 0/1 grid; it is only read, never modified.
    /// @return number of all-ones square submatrices; 0 for an empty grid.
    int countSquares(vector<vector<int>>& matrix) {
        if (matrix.empty() || matrix[0].empty()) return 0;

        const int rows = matrix.size();
        const int cols = matrix[0].size();

        // assign (not resize): resize keeps previously allocated inner rows, which
        // are too short if an earlier call on this object had fewer columns.
        sum.assign(rows + 1, vector<int>(cols + 1, 0));

        // sum[r][c] = sum of matrix[0..r-1][0..c-1]; row 0 and column 0 stay zero.
        for (int r = 1; r <= rows; r++) {
            for (int c = 1; c <= cols; c++) {
                sum[r][c] = sum[r-1][c] + sum[r][c-1] - sum[r-1][c-1] + matrix[r-1][c-1];
            }
        }

        int ans = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                if (matrix[r][c] != 1) continue;

                // (x, y) is the bottom-right corner of the square of side x - r + 1.
                for (int x = r, y = c; x < rows && y < cols; x++, y++) {
                    const int side = x - r + 1;
                    // Every larger square contains this one, so stop at the first failure.
                    if (get_sum(r, c, x, y) != side * side) break;
                    ans++;
                }
            }
        }

        return ans;
    }

    /// Sum of the inclusive rectangle (r1, c1)..(r2, c2) of the matrix passed to
    /// the most recent countSquares call (i.e. its number of ones for a 0/1 grid).
    int get_sum(int r1, int c1, int r2, int c2) {
        return sum[r2+1][c2+1] - sum[r1][c2+1] - sum[r2+1][c1] + sum[r1][c1];
    }

private:
    // 2-D prefix sums with a zero border row/column; rebuilt by each countSquares call.
    vector<vector<int>> sum;
};

// Dynamic programming: O(rows * cols) time, O(cols) extra space.
// Let side(r, c) be the side of the largest all-ones square whose bottom-right
// corner is (r, c): side = min(above, left, diagonal) + 1 for a 1-cell, else 0.
// It also equals the number of all-ones squares ending at (r, c), so the answer
// is the sum of side(r, c) over all cells.
class Solution {
public:
    /// Counts the square submatrices of `matrix` whose cells are all 1.
    /// @param matrix rectangular 0/1 grid; it is only read, never modified
    ///        (the DP uses its own rolling row instead of overwriting the input).
    /// @return number of all-ones square submatrices; 0 for an empty grid.
    int countSquares(vector<vector<int>>& matrix) {
        if (matrix.empty() || matrix[0].empty()) return 0;

        const int cols = matrix[0].size();
        // side[c + 1] holds side(r, c) for columns already visited in the current
        // row and side(r - 1, c) for the rest; side[0] is a permanent zero border.
        vector<int> side(cols + 1, 0);

        int ans = 0;
        for (const auto& row : matrix) {
            int diag = 0;  // side(r - 1, c - 1); zero left of column 0
            for (int c = 0; c < cols; c++) {
                const int above = side[c + 1];  // side(r - 1, c), about to be overwritten
                side[c + 1] = (row[c] == 1) ? min({above, side[c], diag}) + 1 : 0;
                ans += side[c + 1];
                diag = above;
            }
        }

        return ans;
    }
};