2015-08-04 13 views
11

ho due dataframes chiamati sinistra e destra.sostituendo i valori nulli con 0. scintilla dataframe left outer join

scala> left.printSchema 
root 
|-- user_uid: double (nullable = true) 
|-- labelVal: double (nullable = true) 
|-- probability_score: double (nullable = true) 

scala> right.printSchema 
root 
|-- user_uid: double (nullable = false) 
|-- real_labelVal: double (nullable = false) 

Quindi, mi unisco a loro per ottenere il Dataframe unito. È un . Chiunque sia interessato alla funzione natjoin può trovarlo qui.

https://gist.github.com/anonymous/f02bd79528ac75f57ae8

scala> val joinedData = natjoin(predictionDataFrame, labeledObservedDataFrame, "left_outer") 

scala> joinedData.printSchema 
|-- user_uid: double (nullable = true) 
|-- labelVal: double (nullable = true) 
|-- probability_score: double (nullable = true) 
|-- real_labelVal: double (nullable = false) 

Poiché si tratta di un join esterno sinistro, colonna real_labelVal trovi null quando user_uid non è presente in ragione.

scala> val realLabelVal = joinedData.select("real_labelval").distinct.collect 
realLabelVal: Array[org.apache.spark.sql.Row] = Array([0.0], [null]) 

voglio sostituire i valori nulli nella colonna realLabelVal con 1,0.

Attualmente faccio la seguente:

  1. trovo l'indice della colonna di real_labelval e utilizzare l'API spark.sql.Row per impostare i valori nulli a 1,0. (Questo mi dà un RDD [riga])
  2. Quindi applico lo schema del dataframe unito per ottenere il dataframe pulito.

Il codice è il seguente:

val real_labelval_index = 3 
def replaceNull(row: Row) = { 
    val rowArray = row.toSeq.toArray 
    rowArray(real_labelval_index) = 1.0 
    Row.fromSeq(rowArray) 
} 

val cleanRowRDD = joinedData.map(row => if (row.isNullAt(real_labelval_index)) replaceNull(row) else row) 
val cleanJoined = sqlContext.createDataFrame(cleanRowRdd, joinedData.schema) 

C'è un modo elegante o efficiente per fare questo?

Goolging non ha aiutato molto. Grazie in anticipo.

+0

cosa fa il basamento NAT per in natjoin? –

+1

@JosiahYoder nat sta per Natural Join. –

risposta

23

Hai provato a usare na

joinedData.na.fill(1.0, Seq("real_labelval")) 
+0

Grazie per la rapida risposta. Il problema è che usiamo la distribuzione cloudera e il cluster ha la scintilla 1.3.0. Le funzioni di riempimento sono state introdotte nella scintilla 1.4, penso. Lo accetto come risposta. –

+0

Devo importare qualcosa per usare na? Grazie –

+0

@GavinNiu No, 'na' è un metodo direttamente su' DataFrame' –