8

Stavo seguendo il corso this sugli algoritmi del MIT. Nella prima lezione il professore presenta il seguente problema: -Algoritmo di individuazione dei picchi 2D in tempo O (n) nel caso peggiore?

Un picco in un array 2D è un valore tale che tutti i suoi 4 vicini sono inferiori o uguali ad esso, cioè. per

a[i][j] per essere un massimo locale,

a[i+1][j] <= a[i][j] 
&& a[i-1][j] <= a[i][j] 
&& a[i][j+1] <= a[i][j] 
&& a[i+1][j-1] <= a[i][j] 

Ora dato un array NxN 2D, trovare un picco nella matrice.

Questa domanda può essere facilmente risolta nel tempo O(N^2) iterando su tutti gli elementi e restituendo un picco.

Tuttavia può essere ottimizzato per essere risolto nel tempo O(NlogN) utilizzando una soluzione divide e conquista come spiegato here.

Ma hanno detto che esiste un algoritmo di tempo O(N) che risolve questo problema. Si prega di suggerire come possiamo risolvere questo problema nel tempo O(N).

PS (per coloro che conoscono Python) Lo staff del corso ha spiegato un approccio here (Problema 1-5. Prova di ricerca del picco) e fornito anche del codice python nei loro set di problemi. Ma l'approccio spiegato è totalmente non ovvio e molto difficile da decifrare. Il codice Python è altrettanto confuso. Quindi ho copiato la parte principale del codice qui sotto per coloro che conoscono Python e posso dire quale algoritmo viene utilizzato dal codice.

def algorithm4(problem, bestSeen = None, rowSplit = True, trace = None): 
    # if it's empty, we're done 
    if problem.numRow <= 0 or problem.numCol <= 0: 
     return None 

    subproblems = [] 
    divider = [] 

    if rowSplit: 
     # the recursive subproblem will involve half the number of rows 
     mid = problem.numRow // 2 

     # information about the two subproblems 
     (subStartR1, subNumR1) = (0, mid) 
     (subStartR2, subNumR2) = (mid + 1, problem.numRow - (mid + 1)) 
     (subStartC, subNumC) = (0, problem.numCol) 

     subproblems.append((subStartR1, subStartC, subNumR1, subNumC)) 
     subproblems.append((subStartR2, subStartC, subNumR2, subNumC)) 

     # get a list of all locations in the dividing column 
     divider = crossProduct([mid], range(problem.numCol)) 
    else: 
     # the recursive subproblem will involve half the number of columns 
     mid = problem.numCol // 2 

     # information about the two subproblems 
     (subStartR, subNumR) = (0, problem.numRow) 
     (subStartC1, subNumC1) = (0, mid) 
     (subStartC2, subNumC2) = (mid + 1, problem.numCol - (mid + 1)) 

     subproblems.append((subStartR, subStartC1, subNumR, subNumC1)) 
     subproblems.append((subStartR, subStartC2, subNumR, subNumC2)) 

     # get a list of all locations in the dividing column 
     divider = crossProduct(range(problem.numRow), [mid]) 

    # find the maximum in the dividing row or column 
    bestLoc = problem.getMaximum(divider, trace) 
    neighbor = problem.getBetterNeighbor(bestLoc, trace) 

    # update the best we've seen so far based on this new maximum 
    if bestSeen is None or problem.get(neighbor) > problem.get(bestSeen): 
     bestSeen = neighbor 
     if not trace is None: trace.setBestSeen(bestSeen) 

    # return when we know we've found a peak 
    if neighbor == bestLoc and problem.get(bestLoc) >= problem.get(bestSeen): 
     if not trace is None: trace.foundPeak(bestLoc) 
     return bestLoc 

    # figure out which subproblem contains the largest number we've seen so 
    # far, and recurse, alternating between splitting on rows and splitting 
    # on columns 
    sub = problem.getSubproblemContaining(subproblems, bestSeen) 
    newBest = sub.getLocationInSelf(problem, bestSeen) 
    if not trace is None: trace.setProblemDimensions(sub) 
    result = algorithm4(sub, newBest, not rowSplit, trace) 
    return problem.getLocationInSelf(sub, result) 

#Helper Method 
def crossProduct(list1, list2): 
    """ 
    Returns all pairs with one item from the first list and one item from 
    the second list. (Cartesian product of the two lists.) 

    The code is equivalent to the following list comprehension: 
     return [(a, b) for a in list1 for b in list2] 
    but for easier reading and analysis, we have included more explicit code. 
    """ 

    answer = [] 
    for a in list1: 
     for b in list2: 
      answer.append ((a, b)) 
    return answer 
+0

Perché votare per chiudere la domanda? –

+1

Solo un picco casuale o tutti i picchi? – Andrey

+1

Solo un picco casuale –

risposta

4
  1. Supponiamo che la larghezza della matrice è più grande di altezza, altrimenti raggruppati in un'altra direzione.
  2. Dividere l'array in tre parti: colonna centrale, lato sinistro e lato destro.
  3. Passare attraverso la colonna centrale e due colonne adiacenti e cercare il massimo.
    • Se è nella colonna centrale - questo è il nostro picco
    • Se è nella parte sinistra, eseguire questo algoritmo su sottoarray left_side + central_column
    • Se è nella parte destra, eseguire questo algoritmo su sottoarray right_side + central_column

Perché funziona:

Per i casi in cui l'elemento massimo è nella colonna centrale - evidente. Se non lo è, possiamo passare da quel massimo ad elementi crescenti e sicuramente non attraverseremo la fila centrale, quindi un picco sicuramente esisterà nella metà corrispondente.

Perchè questo è O (n):

passo # 3 richiede meno o uguale a max_dimension iterazioni e max_dimension almeno metà su ogni due passi algoritmo. Questo dà n+n/2+n/4+... che è O(n). Dettaglio importante: ci siamo divisi per la direzione massima. Per gli array quadrati ciò significa che le direzioni divergenti si alterneranno. Questa è una differenza rispetto all'ultimo tentativo nel PDF a cui ti sei collegato.

Una nota: non sono sicuro che se corrisponde esattamente all'algoritmo nel codice che hai fornito, potrebbe essere o meno un approccio diverso.

+0

Questo è O (nlogn) nel seguente caso. Esiste una matrice NxN. Ad ogni chiamata ricorsiva sul tuo algoritmo, calcoli il valore massimo nella colonna. Questo succede per i tempi di logn. Quindi la complessità è O (nlogn). E penso che il tuo algoritmo sia simile a quello sulla diapositiva 6 (http://courses.csail.mit.edu/6.006/fall11/lectures/lecture1.pdf). E hanno analizzato questo per essere O (nlogn). –

+0

Ogni volta che la matrice diventa più piccola. Quindi non è 'n + n + n + ...' 'log (n)' volte. È 'n + n/2 + n/4 + ... <2n' – maxim1000

+0

Mi sembra di capire l'algoritmo ora e che sarà O (n). Ma puoi anche aggiungere una relazione di ricorrenza alla tua risposta. Ciò chiarirà ulteriormente che il tempo di esecuzione è O (n). –

1

Here is the working Java code che implementa l'algoritmo di @ maxim1000. Il seguente codice trova un picco nell'array 2D in tempo lineare.

import java.util.*; 

class Ideone{ 
    public static void main (String[] args) throws java.lang.Exception{ 
     new Ideone().run(); 
    } 
    int N , M ; 

    void run(){ 
     N = 1000; 
     M = 100; 

     // arr is a random NxM array 
     int[][] arr = randomArray(); 
     long start = System.currentTimeMillis(); 
//  for(int i=0; i<N; i++){ // TO print the array. 
//   System. out.println(Arrays.toString(arr[i])); 
//  } 
     System.out.println(findPeakLinearTime(arr)); 
     long end = System.currentTimeMillis(); 
     System.out.println("time taken : " + (end-start)); 
    } 

    int findPeakLinearTime(int[][] arr){ 
     int rows = arr.length; 
     int cols = arr[0].length; 
     return kthLinearColumn(arr, 0, cols-1, 0, rows-1); 
    } 

    // helper function that splits on the middle Column 
    int kthLinearColumn(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){ 
     if(loCol==hiCol){ 
      int max = arr[loRow][loCol]; 
      int foundRow = loRow; 
      for(int row = loRow; row<=hiRow; row++){ 
       if(max < arr[row][loCol]){ 
        max = arr[row][loCol]; 
        foundRow = row; 
       } 
      } 
      if(!correctPeak(arr, foundRow, loCol)){ 
       System.out.println("THIS PEAK IS WRONG"); 
      } 
      return max; 
     } 
     int midCol = (loCol+hiCol)/2; 
     int max = arr[loRow][loCol]; 
     for(int row=loRow; row<=hiRow; row++){ 
      max = Math.max(max, arr[row][midCol]); 
     } 
     boolean centralMax = true; 
     boolean rightMax = false; 
     boolean leftMax = false; 

     if(midCol-1 >= 0){ 
      for(int row = loRow; row<=hiRow; row++){ 
       if(arr[row][midCol-1] > max){ 
        max = arr[row][midCol-1]; 
        centralMax = false; 
        leftMax = true; 
       } 
      } 
     } 

     if(midCol+1 < M){ 
      for(int row=loRow; row<=hiRow; row++){ 
       if(arr[row][midCol+1] > max){ 
        max = arr[row][midCol+1]; 
        centralMax = false; 
        leftMax = false; 
        rightMax = true; 
       } 
      } 
     } 

     if(centralMax) return max; 
     if(rightMax) return kthLinearRow(arr, midCol+1, hiCol, loRow, hiRow); 
     if(leftMax) return kthLinearRow(arr, loCol, midCol-1, loRow, hiRow); 
     throw new RuntimeException("INCORRECT CODE"); 
    } 

    // helper function that splits on the middle 
    int kthLinearRow(int[][] arr, int loCol, int hiCol, int loRow, int hiRow){ 
     if(loRow==hiRow){ 
      int ans = arr[loCol][loRow]; 
      int foundCol = loCol; 
      for(int col=loCol; col<=hiCol; col++){ 
       if(arr[loRow][col] > ans){ 
        ans = arr[loRow][col]; 
        foundCol = col; 
       } 
      } 
      if(!correctPeak(arr, loRow, foundCol)){ 
       System.out.println("THIS PEAK IS WRONG"); 
      } 
      return ans; 
     } 
     boolean centralMax = true; 
     boolean upperMax = false; 
     boolean lowerMax = false; 

     int midRow = (loRow+hiRow)/2; 
     int max = arr[midRow][loCol]; 

     for(int col=loCol; col<=hiCol; col++){ 
      max = Math.max(max, arr[midRow][col]); 
     } 

     if(midRow-1>=0){ 
      for(int col=loCol; col<=hiCol; col++){ 
       if(arr[midRow-1][col] > max){ 
        max = arr[midRow-1][col]; 
        upperMax = true; 
        centralMax = false; 
       } 
      } 
     } 

     if(midRow+1<N){ 
      for(int col=loCol; col<=hiCol; col++){ 
       if(arr[midRow+1][col] > max){ 
        max = arr[midRow+1][col]; 
        lowerMax = true; 
        centralMax = false; 
        upperMax = false; 
       } 
      } 
     } 

     if(centralMax) return max; 
     if(lowerMax) return kthLinearColumn(arr, loCol, hiCol, midRow+1, hiRow); 
     if(upperMax) return kthLinearColumn(arr, loCol, hiCol, loRow, midRow-1); 
     throw new RuntimeException("Incorrect code"); 
    } 

    int[][] randomArray(){ 
     int[][] arr = new int[N][M]; 
     for(int i=0; i<N; i++) 
      for(int j=0; j<M; j++) 
       arr[i][j] = (int)(Math.random()*1000000000); 
     return arr; 
    } 

    boolean correctPeak(int[][] arr, int row, int col){//Function that checks if arr[row][col] is a peak or not 
     if(row-1>=0 && arr[row-1][col]>arr[row][col]) return false; 
     if(row+1<N && arr[row+1][col]>arr[row][col]) return false; 
     if(col-1>=0 && arr[row][col-1]>arr[row][col]) return false; 
     if(col+1<M && arr[row][col+1]>arr[row][col]) return false; 
     return true; 
    } 
}